package com.alibaba.alink.operator.batch.feature;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamsIgnoredOnWebUI;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.SelectedColsWithFirstInputSpec;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.dataproc.StringIndexerUtil;
import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter;
import com.alibaba.alink.operator.common.feature.OneHotModelInfo;
import com.alibaba.alink.operator.common.feature.OneHotModelMapper;
import com.alibaba.alink.operator.common.feature.binning.BinDivideType;
import com.alibaba.alink.operator.common.feature.binning.Bins;
import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.params.dataproc.HasSelectedColTypes;
import com.alibaba.alink.params.feature.HasEnableElse;
import com.alibaba.alink.params.feature.OneHotTrainParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.ProjectOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@SelectedColsWithFirstInputSpec
@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH)})
@ParamsIgnoredOnWebUI(names = {"discreteThresholdsArray"})
@NameCn("独热编码训练")
@NameEn("OneHot Encoder Train")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.OneHotEncoder")
/* loaded from: input_file:com/alibaba/alink/operator/batch/feature/OneHotTrainBatchOp.class */
public final class OneHotTrainBatchOp extends BatchOperator<OneHotTrainBatchOp> implements OneHotTrainParams<OneHotTrainBatchOp>, WithModelInfoBatchOp<OneHotModelInfo, OneHotTrainBatchOp, OneHotModelInfoBatchOp> {
    private static final long serialVersionUID = -4869233204093489524L;

    /**
     * Creates the operator without any initial parameters.
     *
     * <p>Equivalent to calling {@link #OneHotTrainBatchOp(Params)} with {@code null}.
     */
    public OneHotTrainBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters handed unchanged to the {@link BatchOperator} constructor
     */
    public OneHotTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the one-hot model from the first linked operator.
     *
     * <p>Pipeline, as wired below:
     * <ol>
     *   <li>the selected columns are passed to {@code StringIndexerUtil.countTokens};</li>
     *   <li>tuples whose third field (presumably the token count) is below the threshold of
     *       their column (first field) are dropped;</li>
     *   <li>the remaining (column, token) pairs go through
     *       {@code StringIndexerUtil.zipWithIndexPerColumn} and are re-projected;</li>
     *   <li>the result is serialized with {@link OneHotModelDataConverter} as the main output;</li>
     *   <li>a side output table ({@code selectedCol}, {@code distinctTokenNumber}) holds, per
     *       column, the maximum of the third projected field plus one.</li>
     * </ol>
     *
     * <p>Thresholds come from {@code discreteThresholdsArray} when that parameter is set,
     * otherwise the single {@code discreteThresholds} value is used for every column.
     *
     * @param batchOperatorArr upstream operators; only the first one is read
     * @return this operator, with its output and side-output tables set
     * @throws IllegalArgumentException if an explicit threshold array has fewer entries than
     *     there are selected columns
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public OneHotTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);
        final String[] selectedCols = getSelectedCols();

        final String[] selectedColTypes = new String[selectedCols.length];
        for (int i = 0; i < selectedCols.length; i++) {
            selectedColTypes[i] = FlinkTypeConverter.getTypeString(
                TableUtil.findColTypeWithAssertAndHint(input.getSchema(), selectedCols[i]));
        }

        final int[] thresholds;
        if (getParams().contains(OneHotTrainParams.DISCRETE_THRESHOLDS_ARRAY)) {
            thresholds = Arrays.stream(getDiscreteThresholdsArray())
                .mapToInt(threshold -> threshold.intValue())
                .toArray();
            // The filter below indexes thresholds by column position; fail here with a clear
            // message instead of an index error inside a distributed task.
            if (thresholds.length < selectedCols.length) {
                throw new IllegalArgumentException(
                    "discreteThresholdsArray has " + thresholds.length + " entries but "
                        + selectedCols.length + " columns are selected.");
            }
        } else {
            thresholds = new int[selectedCols.length];
            Arrays.fill(thresholds, getDiscreteThresholds().intValue());
        }
        final boolean enableElse = OneHotModelMapper.isEnableElse(thresholds);

        DataSet<Tuple3<Integer, String, Long>> frequentTokens = StringIndexerUtil
            .countTokens(input.select(selectedCols).getDataSet(), true)
            .filter(new FilterFunction<Tuple3<Integer, String, Long>>() {
                private static final long serialVersionUID = -8219708805787440332L;

                @Override
                public boolean filter(Tuple3<Integer, String, Long> counted) {
                    return counted.f2 >= (long) thresholds[counted.f0];
                }
            });

        DataSet<Tuple3<Integer, String, Long>> indexedTokens = StringIndexerUtil
            .zipWithIndexPerColumn(frequentTokens.project(0, 1))
            .project(1, 2, 0);

        DataSet<Row> modelRows = indexedTokens
            .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
                private static final long serialVersionUID = 1218440919919078839L;

                @Override
                public void mapPartition(Iterable<Tuple3<Integer, String, Long>> tokens, Collector<Row> out)
                    throws Exception {
                    // Only subtask 0 supplies the meta; the other subtasks pass null.
                    Params meta = null;
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        meta = new Params()
                            .set(HasSelectedCols.SELECTED_COLS, selectedCols)
                            .set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColTypes)
                            .set(HasEnableElse.ENABLE_ELSE, Boolean.valueOf(enableElse));
                    }
                    new OneHotModelDataConverter().save2(Tuple2.of(meta, tokens), out);
                }
            })
            .name("build_model");

        DataSet<Row> distinctTokenNumbers = indexedTokens
            .groupBy(0)
            .aggregate(Aggregations.MAX, 2)
            .map(new MapFunction<Tuple3<Integer, String, Long>, Row>() {
                private static final long serialVersionUID = 8508091938066560805L;

                @Override
                public Row map(Tuple3<Integer, String, Long> maxIndexed) throws Exception {
                    return Row.of(selectedCols[maxIndexed.f0], maxIndexed.f2 + 1);
                }
            });

        setOutput(modelRows, new OneHotModelDataConverter().getModelSchema());
        setSideOutputTables(new Table[] {
            DataSetConversionUtil.toTable(
                getMLEnvironmentId(),
                distinctTokenNumbers,
                new String[] {"selectedCol", "distinctTokenNumber"},
                new TypeInformation<?>[] {Types.STRING, Types.LONG})
        });
        return this;
    }

    /**
     * Creates a {@link OneHotModelInfoBatchOp} configured with this operator's parameters
     * and links it to this operator.
     *
     * @return the result of linking the new model-info operator to this one
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public OneHotModelInfoBatchOp getModelInfoBatchOp() {
        OneHotModelInfoBatchOp modelInfoOp = new OneHotModelInfoBatchOp(getParams());
        return modelInfoOp.linkFrom(this);
    }

    /**
     * Converts serialized one-hot model rows into one {@link FeatureBinsCalculator} per
     * selected column.
     *
     * <p>Rows whose first field is negative are read as meta rows: the second field is parsed
     * as JSON {@link Params}, and an empty discrete calculator is created for every selected
     * column, keyed by the column's position. Rows with a non-negative first field are grouped
     * by that field and left-outer-joined to the calculators on the key; each such row adds a
     * {@code Bins.BaseBin} built from its third field and its second field.
     *
     * <p>A calculator with no matching rows keeps whatever {@code normBins} value
     * {@code new Bins()} gave it; this method only creates the list when it adds a bin.
     *
     * @param dataSet model rows, as described above
     * @return one calculator per selected column
     */
    public static DataSet<FeatureBinsCalculator> transformModelToFeatureBins(DataSet<Row> dataSet) {
        DataSet<Tuple2<Long, FeatureBinsCalculator>> emptyCalculators = dataSet
            .filter(new FilterFunction<Row>() {
                private static final long serialVersionUID = 6390346962976409046L;

                @Override
                public boolean filter(Row row) {
                    return (Long) row.getField(0) < 0;
                }
            })
            .flatMap(new FlatMapFunction<Row, Tuple2<Long, FeatureBinsCalculator>>() {
                private static final long serialVersionUID = -6071069156554419492L;

                @Override
                public void flatMap(Row metaRow, Collector<Tuple2<Long, FeatureBinsCalculator>> out)
                    throws Exception {
                    Params meta = Params.fromJson((String) metaRow.getField(1));
                    String[] colNames = meta.get(HasSelectedCols.SELECTED_COLS);
                    String[] colTypes = meta.get(HasSelectedColTypes.SELECTED_COL_TYPES);
                    for (int i = 0; i < colNames.length; i++) {
                        FeatureBinsCalculator calculator = FeatureBinsCalculator.createDiscreteCalculator(
                            BinDivideType.DISCRETE,
                            colNames[i],
                            FlinkTypeConverter.getFlinkType(colTypes[i]),
                            new Bins());
                        out.collect(Tuple2.of(Long.valueOf(i), calculator));
                    }
                }
            });

        DataSet<Tuple2<Long, List<Row>>> rowsByColumn = dataSet
            .groupBy(0)
            .reduceGroup(new GroupReduceFunction<Row, Tuple2<Long, List<Row>>>() {
                private static final long serialVersionUID = -6261110336454767751L;

                @Override
                public void reduce(Iterable<Row> group, Collector<Tuple2<Long, List<Row>>> out)
                    throws Exception {
                    long columnIndex = -1L;
                    List<Row> rows = new ArrayList<>();
                    for (Row row : group) {
                        columnIndex = (Long) row.getField(0);
                        rows.add(row);
                    }
                    // Groups with a negative key (the meta rows) are not emitted.
                    if (columnIndex >= 0) {
                        out.collect(Tuple2.of(Long.valueOf(columnIndex), rows));
                    }
                }
            });

        return emptyCalculators
            .leftOuterJoin(rowsByColumn)
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Tuple2<Long, FeatureBinsCalculator>, Tuple2<Long, List<Row>>,
                FeatureBinsCalculator>() {
                private static final long serialVersionUID = -301364692111181603L;

                @Override
                public FeatureBinsCalculator join(Tuple2<Long, FeatureBinsCalculator> indexedCalculator,
                                                  Tuple2<Long, List<Row>> columnRows) {
                    FeatureBinsCalculator calculator = indexedCalculator.f1;
                    // Left outer join: the right side is null for a column without rows.
                    if (columnRows != null) {
                        for (Row row : columnRows.f1) {
                            if (calculator.bin.normBins == null) {
                                calculator.bin.normBins = new ArrayList<>();
                            }
                            calculator.bin.normBins.add(new Bins.BaseBin(
                                Long.valueOf(((Long) row.getField(2)).longValue()),
                                (String) row.getField(1)));
                        }
                    }
                    return calculator;
                }
            });
    }

    /**
     * Converts feature-bin calculators back into serialized one-hot model rows.
     *
     * <p>The calculators are numbered with {@code DataSetUtils.zipWithIndex}. Each calculator's
     * (index, feature name, feature type) triple is broadcast under the name
     * {@code "selectedCols"}. Subtask 0 rebuilds the selected-column name and type arrays from
     * that broadcast and puts them into the model meta, with {@code ENABLE_ELSE} set to
     * {@code true}. Every subtask then delegates to
     * {@link #transformFeatureBinsToModel(Iterable, Collector, Params)}.
     *
     * @param dataSet calculators, one per feature
     * @return the model rows written by the delegate
     */
    public static DataSet<Row> transformFeatureBinsToModel(DataSet<FeatureBinsCalculator> dataSet) {
        DataSet<Tuple2<Long, FeatureBinsCalculator>> indexedCalculators = DataSetUtils.zipWithIndex(dataSet);
        return indexedCalculators
            .mapPartition(new RichMapPartitionFunction<Tuple2<Long, FeatureBinsCalculator>, Row>() {
                private static final long serialVersionUID = -2674278791699432005L;

                @Override
                public void mapPartition(Iterable<Tuple2<Long, FeatureBinsCalculator>> calculators,
                                         Collector<Row> out) throws Exception {
                    // Only subtask 0 supplies the meta; the other subtasks pass null.
                    Params meta = null;
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        List<Tuple3<Long, String, String>> columnInfos =
                            getRuntimeContext().getBroadcastVariable("selectedCols");
                        String[] colNames = new String[columnInfos.size()];
                        String[] colTypes = new String[columnInfos.size()];
                        for (Tuple3<Long, String, String> info : columnInfos) {
                            int position = info.f0.intValue();
                            colNames[position] = info.f1;
                            colTypes[position] = info.f2;
                        }
                        meta = new Params()
                            .set(HasSelectedCols.SELECTED_COLS, colNames)
                            .set(HasSelectedColTypes.SELECTED_COL_TYPES, colTypes)
                            .set(HasEnableElse.ENABLE_ELSE, Boolean.TRUE);
                    }
                    OneHotTrainBatchOp.transformFeatureBinsToModel(calculators, out, meta);
                }
            })
            .withBroadcastSet(
                indexedCalculators.map(
                    new MapFunction<Tuple2<Long, FeatureBinsCalculator>, Tuple3<Long, String, String>>() {
                        private static final long serialVersionUID = 888997785087051091L;

                        @Override
                        public Tuple3<Long, String, String> map(Tuple2<Long, FeatureBinsCalculator> indexed)
                            throws Exception {
                            return Tuple3.of(indexed.f0, indexed.f1.getFeatureName(), indexed.f1.getFeatureType());
                        }
                    }),
                "selectedCols");
    }

    /**
     * Flattens feature-bin calculators into (feature index, value, bin index) triples and
     * writes them, with the given meta, through {@link OneHotModelDataConverter}.
     *
     * <p>A calculator whose {@code bin.normBins} is {@code null} contributes no triples.
     * {@code transformModelToFeatureBins} in this class only creates that list when it adds a
     * bin, so {@code null} is a reachable state and used to cause a
     * {@link NullPointerException} here.
     *
     * @param iterable  calculators paired with their feature index
     * @param collector receives the serialized model rows
     * @param params    model meta handed to the converter; callers in this class pass
     *                  {@code null} on every subtask except subtask 0
     */
    public static void transformFeatureBinsToModel(Iterable<Tuple2<Long, FeatureBinsCalculator>> iterable, Collector<Row> collector, Params params) {
        List<Tuple3<Integer, String, Long>> tokens = new ArrayList<>();
        for (Tuple2<Long, FeatureBinsCalculator> indexed : iterable) {
            if (indexed.f1.bin.normBins == null) {
                // No bins for this feature: nothing to emit.
                continue;
            }
            int featureIndex = indexed.f0.intValue();
            for (Bins.BaseBin bin : indexed.f1.bin.normBins) {
                for (String value : bin.getValues()) {
                    tokens.add(Tuple3.of(featureIndex, value, bin.getIndex()));
                }
            }
        }
        new OneHotModelDataConverter().save2(Tuple2.of(params, tokens), collector);
    }

    /**
     * Bridge overload (marked bridge/synthetic by the decompiler) that forwards to
     * {@link #linkFrom(BatchOperator[])} with the array cast to {@code BatchOperator<?>[]}.
     *
     * <p>NOTE(review): this has the same erasure as the varargs {@code linkFrom} above, so it
     * looks like a decompilation artifact rather than hand-written code. Confirm whether it
     * should exist in source form.
     *
     * @param batchOperatorArr upstream operators, forwarded unchanged
     * @return the result of the typed {@code linkFrom}
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ OneHotTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
