package com.alibaba.alink.operator.batch.feature;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.feature.AutoCross.BuildSideOutput;
import com.alibaba.alink.operator.common.feature.AutoCross.DataProfile;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureEvaluator;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureSet;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch trainer for AutoCross: greedily searches for cross features to add to the input features.
 *
 * <p>The search runs as a Flink bulk iteration of at most {@code getMaxSearchStep()} supersteps.
 * In each superstep, every subtask scores a round-robin share of the candidates returned by
 * {@link FeatureSet#generateCandidateCrossFeatures()}. Each candidate is scored by appending it to
 * the cross features picked so far and evaluating with {@link FeatureEvaluator}. The single
 * highest-scoring candidate is then added to the feature set together with its coefficients. The
 * iteration stops early when a superstep produces no candidate.
 *
 * <p>Training rows are fetched per subtask from {@link SessionSharedData} under
 * {@link #AC_TRAIN_DATA} / {@link #SESSION_ID}. They are assumed to have been stored there before
 * this operator's functions run; that happens outside this file — TODO confirm in the base class.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.CROSSED_FEATURES)})
@NameCn("AutoCross训练")
@NameEn("AutoCross Train")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.AutoCross")
public class AutoCrossTrainBatchOp extends BaseCrossTrainBatchOp<AutoCrossTrainBatchOp> {
    private static final long serialVersionUID = 2847616118502942858L;

    /** Key under which each subtask's training rows are looked up in {@link SessionSharedData}. */
    public static final String AC_TRAIN_DATA = "AC_TRAIN_DATA";

    /** Session id paired with {@link #AC_TRAIN_DATA} for {@link SessionSharedData} lookups. */
    public static final int SESSION_ID = SessionSharedData.getNewSessionId();

    // Broadcast-variable names shared between the data-flow wiring and the functions below.
    private static final String BC_FEATURE_SIZE = "featureSize";
    private static final String BC_FEATURE_SET = "featureSet";
    private static final String BC_INITIAL_COEFS = "initialCoefsAndScore";
    private static final String BC_BEST_CANDIDATE = "the_one";
    private static final String BC_BARRIER = "barrier";

    public AutoCrossTrainBatchOp() {
        this(new Params());
    }

    public AutoCrossTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Wires the greedy cross-feature search and turns its final state into model rows.
     *
     * @param trainData        training samples. The functions below do not consume its elements;
     *                         they fetch their data from {@link SessionSharedData}. The data set
     *                         is still the input whose partitions drive the per-subtask work, and
     *                         it is broadcast as "barrier" (presumably to force it to be computed
     *                         first — TODO confirm).
     * @param featureSize      data set holding the per-feature index sizes; only its first element
     *                         is read by the functions that broadcast it
     * @param dataColumnsSaver column metadata; only {@code numericalCols} is used here
     * @return model rows as emitted by {@link BuildModel}
     */
    @Override // com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp
    DataSet<Row> buildAcModelData(DataSet<Tuple3<Double, Double, Vector>> trainData,
                                  DataSet<int[]> featureSize,
                                  BaseCrossTrainBatchOp.DataColumnsSaver dataColumnsSaver) {
        String[] numericalCols = dataColumnsSaver.numericalCols;
        double fraction = getFraction();
        int kCross = getKCross();
        // Only logistic regression is used for candidate evaluation.
        LinearModelType linearModelType = LinearModelType.LR;

        // Baseline model (no cross features) fitted by subtask 0; its coefficients seed the search.
        DataSet<Tuple2<double[], Double>> initialCoefsAndScore = trainData
            .mapPartition(new InitialCoefsAndScore(fraction, linearModelType))
            .withBroadcastSet(trainData, BC_BARRIER);

        IterativeDataSet<FeatureSet> loop = featureSize
            .map(new InitFeatureSet())
            .withBroadcastSet(initialCoefsAndScore, BC_INITIAL_COEFS)
            .iterate(getMaxSearchStep());

        // Score every candidate in parallel, then keep the best one.
        DataSet<Tuple3<int[], Double, double[]>> bestCandidate = trainData
            .mapPartition(new CrossFeatureOperation(numericalCols.length, linearModelType, fraction, true, kCross))
            .withBroadcastSet(loop, BC_FEATURE_SET)
            .withBroadcastSet(featureSize, BC_FEATURE_SIZE)
            .withBroadcastSet(trainData, BC_BARRIER)
            .name("train_and_evaluate")
            .reduce(new KeepBetterCandidate())
            .name("reduce the best one");

        DataSet<FeatureSet> updatedFeatureSet = loop
            .map(new UpdateFeatureSet())
            .withBroadcastSet(bestCandidate, BC_BEST_CANDIDATE)
            .name("update feature set");

        // bestCandidate doubles as the termination criterion: the loop ends once it is empty.
        return loop.closeWith(updatedFeatureSet, bestCandidate)
            .flatMap(new BuildModel("oneHotVectorCol", numericalCols, this.hasDiscrete))
            .withBroadcastSet(featureSize, BC_FEATURE_SIZE);
    }

    /**
     * Produces the single side-output table with columns {@code index, feature, value}.
     *
     * <p>It runs {@link BuildSideOutput} with parallelism 1 over the one-hot model's data set and
     * broadcasts {@code autoCrossModel} to it as "autocrossModel".
     *
     * @param oneHotModel    operator whose data set feeds {@link BuildSideOutput}
     * @param autoCrossModel data set broadcast to {@link BuildSideOutput}
     * @param list           only its size is used, as the {@link BuildSideOutput} argument
     * @param sessionId      passed as the first argument of {@link DataSetConversionUtil#toTable}
     */
    @Override // com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp
    void buildSideOutput(OneHotTrainBatchOp oneHotModel, DataSet<Row> autoCrossModel, List<String> list,
                         long sessionId) {
        DataSet<Row> sideOutput = oneHotModel.getDataSet()
            .mapPartition(new BuildSideOutput(list.size()))
            .withBroadcastSet(autoCrossModel, "autocrossModel")
            .setParallelism(1);
        setSideOutputTables(new Table[] {
            DataSetConversionUtil.toTable(sessionId, sideOutput,
                new String[] {"index", "feature", "value"},
                new TypeInformation<?>[] {Types.INT, Types.STRING, Types.STRING})
        });
    }

    /**
     * Fits the baseline linear model on subtask 0's cached training rows.
     *
     * <p>It emits {@code (coefficients, evaluation score)}. The rows are split with
     * {@link FeatureEvaluator#split} using the configured fraction and seed 0. All other subtasks
     * emit nothing, and the partition iterator itself is not read.
     */
    private static class InitialCoefsAndScore
        extends RichMapPartitionFunction<Tuple3<Double, Double, Vector>, Tuple2<double[], Double>> {
        private static final long serialVersionUID = -6315744108283357226L;
        private final double fraction;
        private final LinearModelType linearModelType;

        InitialCoefsAndScore(double fraction, LinearModelType linearModelType) {
            this.fraction = fraction;
            this.linearModelType = linearModelType;
        }

        @Override
        @SuppressWarnings("unchecked") // SessionSharedData stores untyped objects.
        public void mapPartition(Iterable<Tuple3<Double, Double, Vector>> values,
                                 Collector<Tuple2<double[], Double>> out) throws Exception {
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            if (taskId != 0) {
                return;
            }
            List<Tuple3<Double, Double, Vector>> trainData =
                (List<Tuple3<Double, Double, Vector>>) SessionSharedData.get(AC_TRAIN_DATA, SESSION_ID, taskId);
            Tuple2<List<Tuple3<Double, Double, Vector>>, List<Tuple3<Double, Double, Vector>>> split =
                FeatureEvaluator.split(trainData, this.fraction, 0);
            LinearModelData model = FeatureEvaluator.train(split.f0, new DataProfile(this.linearModelType, true));
            out.collect(Tuple2.of(model.coefVector.getData(), Double.valueOf(FeatureEvaluator.evaluate(model, split.f1))));
        }
    }

    /** Creates the initial (empty-search) {@link FeatureSet} seeded with the baseline coefficients. */
    private static class InitFeatureSet extends RichMapFunction<int[], FeatureSet> {
        private static final long serialVersionUID = 2539121870837769846L;
        private transient Tuple2<double[], Double> initialCoef;

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            List<Tuple2<double[], Double>> coefs = getRuntimeContext().getBroadcastVariable(BC_INITIAL_COEFS);
            this.initialCoef = coefs.get(0);
        }

        @Override
        public FeatureSet map(int[] featureSize) throws Exception {
            FeatureSet featureSet = new FeatureSet(featureSize);
            featureSet.updateFixedCoefs(this.initialCoef.f0);
            return featureSet;
        }
    }

    /**
     * Scores this subtask's share of the candidate cross features for the current superstep.
     *
     * <p>Candidates are dealt round-robin: subtask {@code k} evaluates candidates
     * {@code k, k + numTasks, k + 2 * numTasks, ...}. For each candidate it emits
     * {@code (candidate, score, coefficients)}. The partition iterator is not read; the training
     * rows come from {@link SessionSharedData}.
     */
    private static class CrossFeatureOperation
        extends RichMapPartitionFunction<Tuple3<Double, Double, Vector>, Tuple3<int[], Double, double[]>> {
        private static final long serialVersionUID = -4682615150965402842L;
        private final int numericalSize;
        private final LinearModelType linearModelType;
        private final double fraction;
        private final boolean toFixCoef;
        private final int kCross;
        private transient FeatureSet featureSet;
        private transient int numTasks;
        private transient List<int[]> candidates;
        private transient int[] featureSize;

        CrossFeatureOperation(int numericalSize, LinearModelType linearModelType, double fraction,
                              boolean toFixCoef, int kCross) {
            this.numericalSize = numericalSize;
            this.linearModelType = linearModelType;
            this.fraction = fraction;
            this.toFixCoef = toFixCoef;
            this.kCross = kCross;
        }

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            this.featureSet = getRuntimeContext().<FeatureSet>getBroadcastVariable(BC_FEATURE_SET).get(0);
            this.featureSize = getIterationRuntimeContext().<int[]>getBroadcastVariable(BC_FEATURE_SIZE).get(0);
            this.candidates = this.featureSet.generateCandidateCrossFeatures();
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("\n** step %d, # picked features %d, # total candidates %d",
                    getIterationRuntimeContext().getSuperstepNumber(),
                    this.featureSet.crossFeatureSet.size(),
                    this.candidates.size()));
            }
        }

        @Override
        @SuppressWarnings("unchecked") // SessionSharedData stores untyped objects.
        public void mapPartition(Iterable<Tuple3<Double, Double, Vector>> values,
                                 Collector<Tuple3<int[], Double, double[]>> out) throws Exception {
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            // No candidate is assigned to this subtask: skip loading data and building the evaluator.
            if (this.candidates.size() <= taskId) {
                return;
            }
            List<Tuple3<Double, Double, Vector>> trainData =
                (List<Tuple3<Double, Double, Vector>>) SessionSharedData.get(AC_TRAIN_DATA, SESSION_ID, taskId);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("taskId: " + taskId + ", data size: " + trainData.size());
            }
            FeatureEvaluator evaluator = new FeatureEvaluator(this.linearModelType, trainData, this.featureSize,
                this.featureSet.getFixedCoefs(), this.fraction, this.toFixCoef, this.kCross);
            for (int i = taskId; i < this.candidates.size(); i += this.numTasks) {
                int[] candidate = this.candidates.get(i);
                // Already-picked cross features plus the candidate under evaluation.
                List<int[]> crossFeatures = new ArrayList<>(this.featureSet.crossFeatureSet);
                crossFeatures.add(candidate);
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("evaluating " + JsonConverter.toJson(crossFeatures));
                }
                Tuple2<Double, double[]> score = evaluator.score(crossFeatures, this.numericalSize);
                out.collect(Tuple3.of(candidate, score.f0, score.f1));
            }
        }
    }

    /** Keeps the candidate with the higher score; on ties (or NaN) the second argument wins. */
    private static class KeepBetterCandidate implements ReduceFunction<Tuple3<int[], Double, double[]>> {
        private static final long serialVersionUID = 1099754368531239834L;

        @Override
        public Tuple3<int[], Double, double[]> reduce(Tuple3<int[], Double, double[]> first,
                                                      Tuple3<int[], Double, double[]> second) {
            return first.f1.doubleValue() > second.f1.doubleValue() ? first : second;
        }
    }

    /**
     * Appends the superstep's best candidate to the feature set and replaces its fixed coefficients.
     * If no candidate was broadcast, the feature set is returned unchanged.
     */
    private static class UpdateFeatureSet extends RichMapFunction<FeatureSet, FeatureSet> {
        private static final long serialVersionUID = 1017420195682258788L;

        @Override
        public FeatureSet map(FeatureSet featureSet) {
            List<Tuple3<int[], Double, double[]>> best = getRuntimeContext().getBroadcastVariable(BC_BEST_CANDIDATE);
            if (best.isEmpty()) {
                return featureSet;
            }
            Tuple3<int[], Double, double[]> winner = best.get(0);
            featureSet.addOneCrossFeature(winner.f0, winner.f1.doubleValue());
            featureSet.updateFixedCoefs(winner.f2);
            return featureSet;
        }
    }

    /**
     * Serializes the final {@link FeatureSet} into model rows.
     *
     * <p>It first attaches the vector column name, numerical columns, discrete flag and per-feature
     * index sizes to the feature set. It then emits {@code (0L, featureSet.toString(), null)},
     * followed by one row {@code (i + 1, JSON of cross feature i, score i)} per picked cross
     * feature.
     */
    private static class BuildModel extends RichFlatMapFunction<FeatureSet, Row> {
        private static final long serialVersionUID = -3939236593612638919L;
        private final String vectorCol;
        private final String[] numericalCols;
        private final boolean hasDiscrete;
        private transient int[] indexSize;

        BuildModel(String vectorCol, String[] numericalCols, boolean hasDiscrete) {
            this.vectorCol = vectorCol;
            this.numericalCols = numericalCols;
            this.hasDiscrete = hasDiscrete;
        }

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.indexSize = getRuntimeContext().<int[]>getBroadcastVariable(BC_FEATURE_SIZE).get(0);
        }

        @Override
        public void flatMap(FeatureSet featureSet, Collector<Row> out) throws Exception {
            featureSet.numericalCols = this.numericalCols;
            featureSet.indexSize = this.indexSize;
            featureSet.vecColName = this.vectorCol;
            featureSet.hasDiscrete = this.hasDiscrete;
            out.collect(Row.of(0L, featureSet.toString(), null));
            for (int i = 0; i < featureSet.crossFeatureSet.size(); i++) {
                out.collect(Row.of((long) (i + 1),
                    JsonConverter.toJson(featureSet.crossFeatureSet.get(i)),
                    featureSet.scores.get(i)));
            }
        }
    }
}
