package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.mapper.FlatMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.dataproc.StratifiedSampleParams;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Flat mapper that performs stratified Bernoulli sampling.
 *
 * <p>Each row is kept independently with a probability chosen by the value of its strata column.
 *
 * <ul>
 *   <li>Per-stratum probabilities come from {@link StratifiedSampleParams#STRATA_RATIOS}: a
 *       comma-separated list of {@code key + TimeSeriesAnomsUtils.VAL_DELIMITER + ratio}
 *       entries.</li>
 *   <li>Rows whose strata value has no entry fall back to
 *       {@link StratifiedSampleParams#STRATA_RATIO}.</li>
 *   <li>Strata values are matched by {@code String.valueOf}, so a null value is looked up under
 *       the key {@code "null"}.</li>
 * </ul>
 *
 * <p>Every ratio must lie in [0, 1]. Per-stratum ratios are validated eagerly in the constructor.
 * The fallback ratio is validated only when a row actually needs it, so a configuration that
 * lists every stratum explicitly does not need a valid fallback.
 *
 * <p>Kept rows are emitted unchanged; the output schema is the data schema.
 */
public class StratifiedSampleMapper extends FlatMapper {
    private static final long serialVersionUID = -3276484935413372979L;

    /** Inclusive lower bound of a valid sampling ratio. */
    private static final double MIN_RATIO = 0.0;
    /** Inclusive upper bound of a valid sampling ratio. */
    private static final double MAX_RATIO = 1.0;

    /** Fallback ratio for strata values without an explicit entry (validated lazily). */
    private final double sampleRatio;
    /** Explicit per-stratum ratios keyed by {@code String.valueOf(strataValue)}; never holds null values. */
    private final Map<String, Double> sampleRatios;
    /** Index of the strata column in the input schema. */
    private final int strataColIdx;

    /**
     * Reads and validates the sampling parameters.
     *
     * @param tableSchema schema of the incoming rows; must contain the strata column
     * @param params      operator parameters ({@code STRATA_COL}, {@code STRATA_RATIO},
     *                    {@code STRATA_RATIOS})
     * @throws IllegalArgumentException-style Alink errors (via {@code AkPreconditions}) when a
     *         ratio entry is malformed or a per-stratum ratio is outside [0, 1]
     * @throws NumberFormatException if a per-stratum ratio is not a number
     */
    public StratifiedSampleMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        String strataCol = this.params.get(StratifiedSampleParams.STRATA_COL);
        this.sampleRatio = this.params.get(StratifiedSampleParams.STRATA_RATIO);
        this.sampleRatios = parseStrataRatios(this.params.get(StratifiedSampleParams.STRATA_RATIOS));
        this.strataColIdx = TableUtil.findColIndexWithAssertAndHint(tableSchema.getFieldNames(), strataCol);
    }

    /**
     * Parses the per-stratum ratio list into a map.
     *
     * <p>A null or blank spec yields an empty map, meaning every row uses the fallback ratio.
     * Blank entries (e.g. from a stray leading comma or ",,") are ignored. If a key is repeated,
     * the later entry wins.
     */
    private static Map<String, Double> parseStrataRatios(String spec) {
        Map<String, Double> ratios = new HashMap<>();
        if (spec == null || spec.trim().isEmpty()) {
            return ratios;
        }
        for (String entry : spec.split(",")) {
            if (entry.trim().isEmpty()) {
                continue;
            }
            String[] keyAndRatio = entry.split(TimeSeriesAnomsUtils.VAL_DELIMITER);
            AkPreconditions.checkArgument(keyAndRatio.length == 2, "Invalid format for param ratios.");
            double ratio = Double.parseDouble(keyAndRatio[1]);
            AkPreconditions.checkArgument(isValidRatio(ratio), "Param ratios must be in range [0, 1].");
            ratios.put(keyAndRatio[0], ratio);
        }
        return ratios;
    }

    /** Returns true iff {@code ratio} is in [0, 1]; NaN is rejected. */
    private static boolean isValidRatio(double ratio) {
        return ratio >= MIN_RATIO && ratio <= MAX_RATIO;
    }

    @Override
    public TableSchema getOutputSchema() {
        return getDataSchema();
    }

    /**
     * Emits {@code row} with probability equal to its stratum's ratio.
     *
     * @throws AkIllegalArgumentException if the row's stratum has no explicit ratio and the
     *         fallback ratio is outside [0, 1] (including NaN)
     */
    @Override
    public void flatMap(Row row, Collector<Row> collector) throws Exception {
        String strataKey = String.valueOf(row.getField(this.strataColIdx));
        // The map never stores null values, so a null lookup means "no explicit ratio".
        Double explicitRatio = this.sampleRatios.get(strataKey);
        double ratio;
        if (explicitRatio != null) {
            ratio = explicitRatio;
        } else {
            // Written as !inRange so that a NaN fallback is rejected; otherwise "random < NaN"
            // would silently drop every unmatched row.
            if (!isValidRatio(this.sampleRatio)) {
                throw new AkIllegalArgumentException("Illegal ratio  for [" + strataKey + "]. Please set proper values for ratio or ratios.");
            }
            ratio = this.sampleRatio;
        }
        // ThreadLocalRandom avoids contention on the single JVM-wide generator behind
        // Math.random() when many parallel mapper instances sample concurrently.
        if (ThreadLocalRandom.current().nextDouble() < ratio) {
            collector.collect(row);
        }
    }
}
