package com.alibaba.alink.operator.batch.feature;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.feature.TargetEncoderConverter;
import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData;
import com.alibaba.alink.params.feature.TargetEncoderTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "selectedCols"), @ParamSelectColumnSpec(name = "labelCol")})
@NameCn("TargetEncoder")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.TargetEncoder")
@NameEn("TargetEncoder")
/* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp.class */
public class TargetEncoderTrainBatchOp extends BatchOperator<TargetEncoderTrainBatchOp> implements TargetEncoderTrainParams<TargetEncoderTrainBatchOp> {

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$BuildGroupByCol.class */
    public static class BuildGroupByCol extends RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Integer, Row>> {
        int superStepNumber;
        int lastIndex;
        int[] selectedColIndices;

        BuildGroupByCol(int i, int[] iArr) {
            this.lastIndex = i;
            this.selectedColIndices = iArr;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.superStepNumber = getIterationRuntimeContext().getSuperstepNumber() - 1;
        }

        public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Integer, Row>> collector) throws Exception {
            int i = this.selectedColIndices[this.superStepNumber];
            for (Tuple2<Integer, Row> tuple2 : iterable) {
                ((Row) tuple2.f1).setField(this.lastIndex, ((Row) tuple2.f1).getField(i));
                collector.collect(tuple2);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$BuildIterRes.class */
    public static class BuildIterRes extends RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Integer, Row>> {
        Row items;

        private BuildIterRes() {
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.items = (Row) getRuntimeContext().getBroadcastVariable("rowMeans").get(0);
        }

        public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Integer, Row>> collector) throws Exception {
            int superstepNumber = getIterationRuntimeContext().getSuperstepNumber() - 1;
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            if (superstepNumber % getRuntimeContext().getMaxNumberOfParallelSubtasks() == indexOfThisSubtask) {
                SessionSharedData.put("" + superstepNumber, indexOfThisSubtask, this.items);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$BuildModelData.class */
    public static class BuildModelData extends RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Integer, Row>> {
        int iterNum;

        BuildModelData(int i) {
            this.iterNum = i;
        }

        public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Integer, Row>> collector) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            int maxNumberOfParallelSubtasks = getRuntimeContext().getMaxNumberOfParallelSubtasks();
            for (int i = indexOfThisSubtask; i < this.iterNum; i += maxNumberOfParallelSubtasks) {
                collector.collect(Tuple2.of(0, (Row) SessionSharedData.get("" + i, indexOfThisSubtask)));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$CalcMean.class */
    public static class CalcMean implements GroupCombineFunction<Tuple2<Integer, Row>, Tuple2<Object, Double>> {
        int groupIndex;
        int labelIndex;
        String positiveLabel;

        CalcMean(int i, int i2, String str) {
            this.groupIndex = i;
            this.labelIndex = i2;
            this.positiveLabel = str;
        }

        public void combine(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Object, Double>> collector) throws Exception {
            int i = 0;
            double d = 0.0d;
            Object obj = null;
            for (Tuple2<Integer, Row> tuple2 : iterable) {
                obj = ((Row) tuple2.f1).getField(this.groupIndex);
                i++;
                if (this.positiveLabel == null) {
                    d += ((Double) ((Row) tuple2.f1).getField(this.labelIndex)).doubleValue();
                } else if (((Row) tuple2.f1).getField(this.labelIndex).toString().equals(this.positiveLabel)) {
                    d += 1.0d;
                }
            }
            collector.collect(Tuple2.of(obj, Double.valueOf(d / i)));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$MapInputData.class */
    public static class MapInputData implements MapPartitionFunction<Row, Row> {
        int selectedColSize;
        int originColSize;

        MapInputData(int i, int i2) {
            this.selectedColSize = i;
            this.originColSize = i2;
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
            Row row = new Row(this.selectedColSize + this.originColSize + 1);
            for (Row row2 : iterable) {
                for (int i = 0; i < this.originColSize; i++) {
                    row.setField(i, row2.getField(i));
                }
                collector.collect(row);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$ReduceColumnInfo.class */
    public static class ReduceColumnInfo extends RichGroupReduceFunction<Tuple2<Object, Double>, Tuple2<String, HashMap<String, Double>>> {
        String[] selectedCols;

        ReduceColumnInfo(String[] strArr) {
            this.selectedCols = strArr;
        }

        public void reduce(Iterable<Tuple2<Object, Double>> iterable, Collector<Tuple2<String, HashMap<String, Double>>> collector) throws Exception {
            HashMap hashMap = new HashMap();
            for (Tuple2<Object, Double> tuple2 : iterable) {
                hashMap.put(tuple2.f0.toString(), tuple2.f1);
            }
            collector.collect(Tuple2.of(this.selectedCols[getIterationRuntimeContext().getSuperstepNumber() - 1], hashMap));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$ReduceModelData.class */
    public static class ReduceModelData implements GroupReduceFunction<Tuple2<Integer, Row>, Row> {
        private ReduceModelData() {
        }

        public void reduce(Iterable<Tuple2<Integer, Row>> iterable, Collector<Row> collector) throws Exception {
            TargetEncoderConverter targetEncoderConverter = new TargetEncoderConverter();
            Iterator<Tuple2<Integer, Row>> it = iterable.iterator();
            while (it.hasNext()) {
                Row row = (Row) it.next().f1;
                row.setField(1, JsonConverter.toJson(row.getField(1)));
                targetEncoderConverter.save2(row, collector);
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp$RowKeySelector.class */
    public static class RowKeySelector implements KeySelector<Tuple2<Integer, Row>, Comparable> {
        int index;

        public RowKeySelector(int i) {
            this.index = i;
        }

        public Comparable getKey(Tuple2<Integer, Row> tuple2) {
            return (Comparable) ((Row) tuple2.f1).getField(this.index);
        }
    }

    /**
     * Creates the operator with an empty parameter set.
     */
    public TargetEncoderTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters passed on to the {@link BatchOperator} base class
     */
    public TargetEncoderTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the training job: for every selected column, the mean label value per distinct
     * column value is computed, and the per-column results are written out as the model table.
     *
     * <p>Processing runs as a delta iteration with one superstep per selected column. In each
     * superstep the column's value is copied into an extra trailing row field
     * ({@link BuildGroupByCol}), rows are grouped on that field, the per-group mean is computed,
     * and all means of the column are gathered into one map ({@link ReduceColumnInfo}). That
     * map is broadcast as {@code "rowMeans"} to {@link BuildIterRes}, which stashes it in
     * {@link SessionSharedData}; after the iteration {@link BuildModelData} and
     * {@link ReduceModelData} turn the stashed rows into model rows.
     *
     * <p>If no columns are selected explicitly, every categorical column except the label
     * column is used.
     *
     * <p>The per-group mean is computed with a full {@code reduceGroup}. A {@code combineGroup}
     * only combines locally per partition, so a key spread over several partitions would
     * yield several partial means that overwrite one another downstream.
     *
     * <p>No explicit {@code linkFrom(BatchOperator[])} bridge is declared: it has the same
     * erasure as this varargs method and the compiler generates it automatically.
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator, with its output table set to the model
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public TargetEncoderTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);
        String labelCol = getLabelCol();
        String[] selectedCols = getSelectedCols();
        if (selectedCols == null) {
            // Default selection: all categorical columns other than the label.
            ArrayList<String> candidates = new ArrayList<>();
            for (String col : TableUtil.getCategoricalCols(input.getSchema(), input.getColNames(), null)) {
                if (!col.equals(labelCol)) {
                    candidates.add(col);
                }
            }
            selectedCols = candidates.toArray(new String[0]);
        }

        int[] selectedColIndices = TableUtil.findColIndices(input.getSchema(), selectedCols);
        int labelIndex = TableUtil.findColIndex(input.getSchema(), labelCol);
        String positiveLabel = getPositiveLabelValueString();
        int originColSize = input.getColNames().length;
        int selectedColSize = selectedColIndices.length;
        // Trailing row field that carries the grouping value of the current superstep.
        int groupIndex = originColSize + selectedColSize;

        DataSet<Tuple2<Integer, Row>> inputData = input.getDataSet()
            .mapPartition(new MapInputData(selectedColSize, originColSize))
            .map(new MapFunction<Row, Tuple2<Integer, Row>>() {
                @Override
                public Tuple2<Integer, Row> map(Row row) throws Exception {
                    return Tuple2.of(0, row);
                }
            });

        // Single dummy element used as the initial solution set of the iteration.
        DataSet<Tuple2<Integer, Row>> seed = MLEnvironmentFactory.get(getMLEnvironmentId())
            .getExecutionEnvironment()
            .fromElements(Tuple2.of(0, new Row(0)));

        // NOTE(review): with no selected columns the iteration count is 0 - confirm how the
        // delta iteration treats that case.
        DeltaIteration<Tuple2<Integer, Row>, Tuple2<Integer, Row>> iteration =
            seed.iterateDelta(inputData, selectedColSize, 0);

        DataSet<Tuple2<Integer, Row>> groupData = iteration.getWorkset()
            .mapPartition(new BuildGroupByCol(groupIndex, selectedColIndices));

        DataSet<Row> rowMeans = groupData
            .groupBy(new RowKeySelector(groupIndex))
            .reduceGroup(new CalcGroupMean(new CalcMean(groupIndex, labelIndex, positiveLabel)))
            .reduceGroup(new ReduceColumnInfo(selectedCols))
            .map(new MapFunction<Tuple2<String, HashMap<String, Double>>, Row>() {
                @Override
                public Row map(Tuple2<String, HashMap<String, Double>> tuple2) throws Exception {
                    Row row = new Row(2);
                    row.setField(0, tuple2.f0);
                    row.setField(1, tuple2.f1);
                    return row;
                }
            })
            .returns(TypeInformation.of(Row.class));

        DataSet<Tuple2<Integer, Row>> solutionDelta = seed
            .mapPartition(new BuildIterRes())
            .withBroadcastSet(rowMeans, "rowMeans");

        // The unchanged workset is fed back so that each superstep sees all rows.
        DataSet<Row> model = iteration.closeWith(solutionDelta, groupData)
            .mapPartition(new BuildModelData(selectedColSize))
            .reduceGroup(new ReduceModelData());

        setOutputTable(DataSetConversionUtil.toTable(
            getMLEnvironmentId(), model, new TargetEncoderConverter(selectedCols).getModelSchema()));
        return this;
    }

    /**
     * Runs {@link CalcMean} over complete groups by exposing its logic as a
     * {@link GroupReduceFunction}. Delegation is used instead of inheritance so that this
     * function does not itself implement {@link GroupCombineFunction}.
     */
    private static class CalcGroupMean implements GroupReduceFunction<Tuple2<Integer, Row>, Tuple2<Object, Double>> {
        private final CalcMean delegate;

        CalcGroupMean(CalcMean delegate) {
            this.delegate = delegate;
        }

        @Override
        public void reduce(Iterable<Tuple2<Integer, Row>> values, Collector<Tuple2<Object, Double>> out) throws Exception {
            this.delegate.combine(values, out);
        }
    }
}
