package com.alibaba.alink.operator.local.feature;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureEvaluator;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureSet;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.local.feature.BaseCrossTrainLocalOp;
import com.alibaba.alink.params.feature.CrossCandidateSelectorTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Local training operator that scores every user-supplied cross-feature candidate and keeps the
 * best-scoring ones.
 *
 * <p>Each entry of {@code getFeatureCandidates()} is a comma-separated list of column names. Every
 * candidate is scored on its own with a {@link FeatureEvaluator}; the candidates are then ordered by
 * descending score and the first {@code getCrossFeatureNumber()} of them are stored in the model.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@NameCn("cross候选特征选择训练")
@NameEn("Cross Candidate Selector Train")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.CrossCandidateSelector")
public class CrossCandidateSelectorTrainLocalOp extends BaseCrossTrainLocalOp<CrossCandidateSelectorTrainLocalOp> implements CrossCandidateSelectorTrainParams<CrossCandidateSelectorTrainLocalOp> {
    /** Creates the operator with an empty parameter set. */
    public CrossCandidateSelectorTrainLocalOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters forwarded to the base operator
     */
    public CrossCandidateSelectorTrainLocalOp(Params params) {
        super(params);
    }

    /**
     * Builds the model rows describing the selected cross features.
     *
     * <p>Row 0 holds the serialized {@link FeatureSet} (via {@code toString()}) with a {@code null}
     * third field. Rows 1..n hold, for each selected cross feature, its 1-based id, its JSON-encoded
     * column indices and its score.
     *
     * <p>If {@code getCrossFeatureNumber()} is larger than the number of candidates, every candidate
     * is kept instead of failing with an {@link IndexOutOfBoundsException}.
     *
     * @param list             training samples handed to the {@link FeatureEvaluator}
     * @param iArr             array passed to the {@link FeatureSet} constructor and stored as its
     *                         {@code indexSize} (presumably per-feature index sizes; confirm against
     *                         the base class)
     * @param dataColumnsSaver supplies the categorical/numerical column names and numerical indices
     * @return the model rows, header row first
     */
    @Override // com.alibaba.alink.operator.local.feature.BaseCrossTrainLocalOp
    List<Row> buildAcModelData(List<Tuple3<Double, Double, Vector>> list, int[] iArr, BaseCrossTrainLocalOp.DataColumnsSaver dataColumnsSaver) {
        String[] numericalCols = dataColumnsSaver.numericalCols;
        List<Tuple4<int[], Double, double[], Integer>> scoredCandidates = calcAuc(
            list, iArr, dataColumnsSaver.categoricalCols, dataColumnsSaver.numericalIndices.length, getFeatureCandidates());

        FeatureSet featureSet = new FeatureSet(iArr);
        featureSet.numericalCols = numericalCols;
        featureSet.indexSize = iArr;
        featureSet.vecColName = "oneHotVectorCol";
        featureSet.hasDiscrete = true;

        // Highest score first; List.sort is stable, so ties keep their candidate order.
        scoredCandidates.sort((left, right) -> right.f1.compareTo(left.f1));

        // Never read past the end of the candidate list when more features are requested than exist.
        int requestedCount = getCrossFeatureNumber();
        int selectedCount = Math.min(requestedCount, scoredCandidates.size());
        for (int i = 0; i < selectedCount; i++) {
            Tuple4<int[], Double, double[], Integer> candidate = scoredCandidates.get(i);
            featureSet.addOneCrossFeature(candidate.f0, candidate.f1);
        }

        List<Row> modelRows = new ArrayList<>();
        modelRows.add(Row.of(0L, featureSet.toString(), null));
        for (int i = 0; i < featureSet.crossFeatureSet.size(); i++) {
            modelRows.add(Row.of(
                (long) (i + 1),
                JsonConverter.toJson(featureSet.crossFeatureSet.get(i)),
                featureSet.scores.get(i)));
        }
        return modelRows;
    }

    /** No-op: this operator does not populate any side output. */
    @Override // com.alibaba.alink.operator.local.feature.BaseCrossTrainLocalOp
    void buildSideOutput(OneHotTrainLocalOp oneHotTrainLocalOp, List<Row> list, List<String> list2) {
    }

    /**
     * Scores each cross-feature candidate independently.
     *
     * @param trainData         training samples handed to the {@link FeatureEvaluator}
     * @param featureSize       array handed to the {@link FeatureEvaluator} constructor
     * @param categoricalCols   column names against which candidate column names are resolved
     * @param numericalSize     value forwarded as the second argument of
     *                          {@link FeatureEvaluator#score}
     * @param featureCandidates candidates, each a comma-separated list of column names; surrounding
     *                          whitespace around each name is ignored
     * @return one tuple per candidate, in input order: resolved column indices, the score, the
     *     second component returned by the evaluator, and a constant {@code 0}
     */
    private static List<Tuple4<int[], Double, double[], Integer>> calcAuc(List<Tuple3<Double, Double, Vector>> trainData, int[] featureSize, String[] categoricalCols, int numericalSize, String[] featureCandidates) {
        // Resolve every candidate's column names to indices within categoricalCols.
        List<int[]> candidateIndices = new ArrayList<>(featureCandidates.length);
        for (String candidate : featureCandidates) {
            String[] colNames = candidate.split(",");
            for (int i = 0; i < colNames.length; i++) {
                colNames[i] = colNames[i].trim();
            }
            candidateIndices.add(TableUtil.findColIndices(categoricalCols, colNames));
        }

        // NOTE(review): the meaning of the trailing constructor arguments (null, 0.8, false, 1) is
        // defined by FeatureEvaluator and is not visible here.
        FeatureEvaluator featureEvaluator = new FeatureEvaluator(LinearModelType.LR, trainData, featureSize, null, 0.8d, false, 1);

        List<Tuple4<int[], Double, double[], Integer>> scored = new ArrayList<>(candidateIndices.size());
        for (int[] indices : candidateIndices) {
            // The evaluator takes a list of cross features; each candidate is scored on its own.
            List<int[]> singleCandidate = new ArrayList<>();
            singleCandidate.add(indices);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("evaluating " + JsonConverter.toJson(singleCandidate));
            }
            Tuple2<Double, double[]> score = featureEvaluator.score(singleCandidate, numericalSize);
            scored.add(Tuple4.of(indices, score.f0, score.f1, 0));
        }
        return scored;
    }
}
