package com.alibaba.alink.operator.batch.feature;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureEvaluator;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureSet;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData;
import com.alibaba.alink.params.feature.CrossCandidateSelectorTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@NameCn("cross候选特征选择训练")
@NameEn("Cross Candidate Selector Train")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.CrossCandidateSelector")
/* loaded from: input_file:com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorTrainBatchOp.class */
public class CrossCandidateSelectorTrainBatchOp extends BaseCrossTrainBatchOp<CrossCandidateSelectorTrainBatchOp> implements CrossCandidateSelectorTrainParams<CrossCandidateSelectorTrainBatchOp> {

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorTrainBatchOp$CalcAucOfCandidate.class */
    /**
     * Scores every candidate cross feature with a {@link FeatureEvaluator} and emits one tuple per candidate.
     *
     * <p>Candidates are distributed round-robin over the parallel subtasks: subtask {@code k} handles the
     * candidates at positions {@code k, k + numTasks, k + 2 * numTasks, ...}. The partition's own input
     * iterable is never read; the training data is fetched from {@link SessionSharedData} under the key
     * {@code "AC_TRAIN_DATA"} instead.
     *
     * <p>Each emitted tuple is {@code (candidate column indices, score.f0, score.f1, 0)}. The fourth field
     * is always the constant {@code 0}; {@code buildAcModelData} hash-partitions on that field.
     */
    private static class CalcAucOfCandidate extends RichMapPartitionFunction<Tuple3<Double, Double, Vector>, Tuple4<int[], Double, double[], Integer>> {
        private List<int[]> candidateIndices;
        private int[] featureSize;
        private int numTasks;
        private int numericalSize;

        /**
         * @param categoricalCols   column names that the candidate column names are resolved against
         * @param numericalSize     value forwarded unchanged to {@code FeatureEvaluator.score}
         * @param featureCandidates one entry per candidate; each entry is a comma-separated list of column names
         */
        CalcAucOfCandidate(String[] categoricalCols, int numericalSize, String[] featureCandidates) {
            this.numericalSize = numericalSize;
            this.candidateIndices = new ArrayList<>(featureCandidates.length);
            for (String candidate : featureCandidates) {
                String[] rawNames = candidate.split(",");
                // Surrounding whitespace is stripped so "a, b" and "a,b" resolve to the same columns.
                String[] colNames = new String[rawNames.length];
                for (int k = 0; k < rawNames.length; k++) {
                    colNames[k] = rawNames[k].trim();
                }
                this.candidateIndices.add(TableUtil.findColIndices(categoricalCols, colNames));
            }
        }

        /** Caches the parallelism and the first element of the {@code "featureSize"} broadcast variable. */
        public void open(Configuration configuration) throws Exception {
            this.featureSize = (int[]) getRuntimeContext().getBroadcastVariable("featureSize").get(0);
            this.numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
        }

        /**
         * Evaluates this subtask's share of the candidates and collects one result tuple per candidate.
         *
         * @param iterable  partition input; not consumed by this function
         * @param collector receives {@code (candidate indices, score.f0, score.f1, 0)} for each evaluated candidate
         */
        public void mapPartition(Iterable<Tuple3<Double, Double, Vector>> iterable, Collector<Tuple4<int[], Double, double[], Integer>> collector) throws Exception {
            final int taskId = getRuntimeContext().getIndexOfThisSubtask();
            final int candidateCount = this.candidateIndices.size();
            // Subtasks beyond the number of candidates have nothing to evaluate and must not touch the shared data.
            if (taskId >= candidateCount) {
                return;
            }

            List trainData = (List) SessionSharedData.get("AC_TRAIN_DATA", AutoCrossTrainBatchOp.SESSION_ID, taskId);
            FeatureEvaluator evaluator = new FeatureEvaluator(LinearModelType.LR, trainData, this.featureSize, null, 0.8d, false, 1);

            for (int idx = taskId; idx < candidateCount; idx += this.numTasks) {
                int[] candidate = this.candidateIndices.get(idx);
                // The evaluator takes a list of cross features; each candidate is scored on its own.
                List<int[]> singleCandidate = new ArrayList<>();
                singleCandidate.add(candidate);
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("evaluating " + JsonConverter.toJson(singleCandidate));
                }
                Tuple2<Double, double[]> score = evaluator.score(singleCandidate, this.numericalSize);
                double scoreValue = score.f0;
                collector.collect(Tuple4.of(candidate, scoreValue, score.f1, 0));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorTrainBatchOp$FilterAuc.class */
    /**
     * Keeps the best-scoring candidate cross features and serializes them as model rows.
     *
     * <p>All tuples of the partition are buffered in memory, sorted by their score (field {@code f1})
     * in descending order, and at most {@code crossFeatureNumber} of them are added to a {@link FeatureSet}.
     * The output is:
     * <ul>
     *   <li>one row {@code (0L, featureSet.toString(), null)}, followed by</li>
     *   <li>one row {@code (i + 1, JSON of the i-th cross feature, i-th score)} per entry of
     *       {@code featureSet.crossFeatureSet}.</li>
     * </ul>
     * An empty partition produces no output.
     */
    private static class FilterAuc extends RichMapPartitionFunction<Tuple4<int[], Double, double[], Integer>, Row> {
        private int crossFeatureNumber;
        private int[] indexSize;
        private String vectorCol;
        private String[] numericalCols;

        /**
         * @param crossFeatureNumber maximum number of candidates to keep
         * @param vectorCol          value stored in {@code FeatureSet.vecColName}
         * @param numericalCols      value stored in {@code FeatureSet.numericalCols}
         */
        FilterAuc(int crossFeatureNumber, String vectorCol, String[] numericalCols) {
            this.crossFeatureNumber = crossFeatureNumber;
            this.vectorCol = vectorCol;
            this.numericalCols = numericalCols;
        }

        /** Reads the first element of the {@code "featureSize"} broadcast variable. */
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.indexSize = (int[]) getRuntimeContext().getBroadcastVariable("featureSize").get(0);
        }

        /**
         * Selects the top-scoring candidates of this partition and emits the model rows.
         *
         * @param iterable  scored candidates {@code (indices, score, f2, f3)}; only {@code f0} and {@code f1} are read
         * @param collector receives the header row and one row per selected cross feature
         */
        public void mapPartition(Iterable<Tuple4<int[], Double, double[], Integer>> iterable, Collector<Row> collector) throws Exception {
            List<Tuple4<int[], Double, double[], Integer>> scored = new ArrayList<>();
            for (Tuple4<int[], Double, double[], Integer> entry : iterable) {
                scored.add(entry);
            }
            if (scored.isEmpty()) {
                return;
            }

            FeatureSet featureSet = new FeatureSet(this.indexSize);
            featureSet.numericalCols = this.numericalCols;
            featureSet.indexSize = this.indexSize;
            featureSet.vecColName = this.vectorCol;
            featureSet.hasDiscrete = true;

            // Highest score first.
            Comparator<Tuple4<int[], Double, double[], Integer>> byScoreDescending =
                (left, right) -> Double.compare(right.f1, left.f1);
            scored.sort(byScoreDescending);

            // Never read past the end of the list: when fewer candidates were scored than requested,
            // keep all of them instead of failing with IndexOutOfBoundsException.
            int selectedCount = Math.min(this.crossFeatureNumber, scored.size());
            for (int i = 0; i < selectedCount; i++) {
                Tuple4<int[], Double, double[], Integer> best = scored.get(i);
                featureSet.addOneCrossFeature(best.f0, best.f1);
            }

            // Row 0 carries the serialized feature set; rows 1..n carry the individual cross features.
            collector.collect(Row.of(0L, featureSet.toString(), null));
            for (int i = 0; i < featureSet.crossFeatureSet.size(); i++) {
                collector.collect(Row.of(Long.valueOf(i + 1), JsonConverter.toJson(featureSet.crossFeatureSet.get(i)), featureSet.scores.get(i)));
            }
        }
    }

    /** Creates the operator with a fresh, empty {@link Params} instance. */
    public CrossCandidateSelectorTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameter set passed straight to the {@link BaseCrossTrainBatchOp} constructor
     */
    public CrossCandidateSelectorTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the model rows in two stages: score every configured feature candidate, then keep the best ones.
     *
     * <p>Stage one runs {@code CalcAucOfCandidate} over {@code dataSet}, with {@code dataSet2} broadcast as
     * {@code "featureSize"} and {@code dataSet} itself broadcast as {@code "barrier"}. Stage two
     * hash-partitions the scored tuples on field index 3 and runs {@code FilterAuc}, again with
     * {@code dataSet2} broadcast as {@code "featureSize"}.
     *
     * @param dataSet          input partitions of the scoring stage, also broadcast under the name {@code "barrier"}
     * @param dataSet2         data set broadcast to both stages under the name {@code "featureSize"}
     * @param dataColumnsSaver supplies the categorical column names, the numerical indices and the numerical column names
     * @return the rows emitted by {@code FilterAuc}
     */
    @Override // com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp
    DataSet<Row> buildAcModelData(DataSet<Tuple3<Double, Double, Vector>> dataSet, DataSet<int[]> dataSet2, BaseCrossTrainBatchOp.DataColumnsSaver dataColumnsSaver) {
        CalcAucOfCandidate candidateScorer = new CalcAucOfCandidate(
            dataColumnsSaver.categoricalCols,
            dataColumnsSaver.numericalIndices.length,
            getFeatureCandidates());
        DataSet<Tuple4<int[], Double, double[], Integer>> scoredCandidates = dataSet
            .mapPartition(candidateScorer)
            .withBroadcastSet(dataSet2, "featureSize")
            .withBroadcastSet(dataSet, "barrier");

        FilterAuc topSelector = new FilterAuc(
            getCrossFeatureNumber().intValue(),
            "oneHotVectorCol",
            dataColumnsSaver.numericalCols);
        return scoredCandidates
            .partitionByHash(3)
            .mapPartition(topSelector)
            .withBroadcastSet(dataSet2, "featureSize");
    }

    /**
     * Side-output hook inherited from {@link BaseCrossTrainBatchOp}; this operator's implementation
     * does nothing and ignores all of its arguments.
     */
    @Override // com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp
    void buildSideOutput(OneHotTrainBatchOp oneHotTrainBatchOp, DataSet<Row> dataSet, List<String> list, long j) {
        // Empty body: no side output is built here.
    }
}
