package com.alibaba.alink.operator.local.feature;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.feature.AutoCross.DataProfile;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureEvaluator;
import com.alibaba.alink.operator.common.feature.AutoCross.FeatureSet;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.optim.LocalOptimizer;
import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData;
import com.alibaba.alink.operator.local.AlinkLocalSession;
import com.alibaba.alink.operator.local.feature.BaseCrossTrainLocalOp;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Local trainer for AutoCross: greedily searches for cross features that improve a linear model's
 * validation score.
 *
 * <p>The search works as follows:
 * <ol>
 *   <li>Split the training rows into a fitting part and a validation part
 *       ({@link FeatureEvaluator#split}).</li>
 *   <li>Fit a baseline logistic-regression model and record its coefficients as the fixed
 *       coefficients of the {@link FeatureSet}.</li>
 *   <li>For up to {@code maxSearchStep} rounds:
 *     <ul>
 *       <li>score every candidate cross feature in parallel;</li>
 *       <li>add the best one to the set;</li>
 *       <li>adopt the coefficients it produced.</li>
 *     </ul>
 *   </li>
 * </ol>
 * The search stops early when the feature set offers no further candidates.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.CROSSED_FEATURES)})
@NameCn("AutoCross训练")
@NameEn("AutoCross Train")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.AutoCross")
public class AutoCrossTrainLocalOp extends BaseCrossTrainLocalOp<AutoCrossTrainLocalOp> {
    public static final String AC_TRAIN_DATA = "AC_TRAIN_DATA";
    public static final int SESSION_ID = SessionSharedData.getNewSessionId();

    /** Fixed seed handed to {@link FeatureEvaluator#split}; presumably keeps the split reproducible. */
    private static final int SPLIT_SEED = 202303;

    /** Vector column name recorded in the serialized {@link FeatureSet}. */
    private static final String ONE_HOT_VECTOR_COL = "oneHotVectorCol";

    public AutoCrossTrainLocalOp() {
        this(new Params());
    }

    public AutoCrossTrainLocalOp(Params params) {
        super(params);
    }

    /**
     * Runs the greedy cross-feature search and serializes the result as model rows.
     *
     * @param trainData        training samples; each tuple is handed unchanged to {@link FeatureEvaluator}
     * @param featureSize      per-feature index sizes, used to build the {@link FeatureSet}
     * @param dataColumnsSaver supplies the numerical column names
     * @return model rows, laid out as follows:
     *         <ul>
     *           <li>row 0 is {@code (0L, featureSet.toString(), null)};</li>
     *           <li>row {@code i + 1} is {@code (i + 1, json(crossFeature_i), score_i)} for each selected cross
     *               feature, in selection order.</li>
     *         </ul>
     */
    @Override
    List<Row> buildAcModelData(List<Tuple3<Double, Double, Vector>> trainData, int[] featureSize,
                               BaseCrossTrainLocalOp.DataColumnsSaver dataColumnsSaver) {
        String[] numericalCols = dataColumnsSaver.numericalCols;
        double fraction = getFraction();
        int kCross = getKCross();
        int maxSearchStep = getMaxSearchStep();
        LinearModelType linearModelType = LinearModelType.LR;

        // Baseline: fit on the first part of the split and score on the second.
        Tuple2<List<Tuple3<Double, Double, Vector>>, List<Tuple3<Double, Double, Vector>>> split =
            FeatureEvaluator.split(trainData, fraction, SPLIT_SEED);
        LinearModelData baseModel = FeatureEvaluator.train(split.f0, new DataProfile(linearModelType, true));
        DenseVector baseCoefs = baseModel.coefVector;
        double baseScore = FeatureEvaluator.evaluate(baseModel, split.f1);
        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.println("origin score: " + baseScore);
        }

        FeatureSet featureSet = new FeatureSet(featureSize);
        featureSet.updateFixedCoefs(baseCoefs.getData());

        int numThreads = LocalOptimizer.getNumThreads(trainData, getParams());
        for (int step = 0; step < maxSearchStep; step++) {
            List<int[]> candidates = featureSet.generateCandidateCrossFeatures();
            int size = candidates.size();
            if (size == 0) {
                // No candidates left: stop here instead of indexing into an empty result array.
                break;
            }

            @SuppressWarnings("unchecked")
            Tuple3<int[], Double, double[]>[] results = (Tuple3<int[], Double, double[]>[]) new Tuple3[size];
            // Never start more workers than there are candidates; each worker builds its own evaluator.
            int numTasks = Math.max(1, Math.min(numThreads, size));

            AlinkLocalSession.TaskRunner taskRunner = new AlinkLocalSession.TaskRunner();
            for (int t = 0; t < numTasks; t++) {
                final int taskId = t;
                taskRunner.submit(() -> {
                    // The 6th argument keeps the original's value (true); its meaning is defined by
                    // FeatureEvaluator (presumably "keep the fixed coefficients fixed") - confirm there.
                    FeatureEvaluator evaluator = new FeatureEvaluator(linearModelType, trainData, featureSize,
                        featureSet.getFixedCoefs(), fraction, true, kCross);
                    // Worker t handles candidates t, t + numTasks, t + 2 * numTasks, ...
                    // Each worker writes only its own slots in results, so no locking is needed.
                    for (int c = taskId; c < size; c += numTasks) {
                        int[] candidate = candidates.get(c);
                        List<int[]> crossFeatures = new ArrayList<>(featureSet.crossFeatureSet);
                        crossFeatures.add(candidate);
                        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                            System.out.println("curThread:" + taskId + " evaluating " + JsonConverter.toJson(crossFeatures));
                        }
                        Tuple2<Double, double[]> score = evaluator.score(crossFeatures, numericalCols.length);
                        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                            System.out.println("curThread:" + taskId + ", " + JsonConverter.toJson(crossFeatures)
                                + " evaluating score: " + score.f0);
                        }
                        results[c] = Tuple3.of(candidate, score.f0, score.f1);
                    }
                });
            }
            // NOTE(review): assumes join() waits for all submitted tasks and makes their writes visible.
            taskRunner.join();

            Tuple3<int[], Double, double[]> best = selectBest(results);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("best auc: " + best.f1 + JsonConverter.toJson(best.f0));
            }
            featureSet.addOneCrossFeature(best.f0, best.f1);
            featureSet.updateFixedCoefs(best.f2);
        }
        return toModelRows(featureSet, numericalCols, featureSize);
    }

    /**
     * Returns the result with the strictly highest score.
     *
     * <p>On ties, the earliest candidate wins.
     */
    private static Tuple3<int[], Double, double[]> selectBest(Tuple3<int[], Double, double[]>[] results) {
        Tuple3<int[], Double, double[]> best = results[0];
        for (int i = 1; i < results.length; i++) {
            if (results[i].f1 > best.f1) {
                best = results[i];
            }
        }
        return best;
    }

    /** Fills in the feature set's metadata and serializes it plus each selected cross feature into rows. */
    private static List<Row> toModelRows(FeatureSet featureSet, String[] numericalCols, int[] featureSize) {
        featureSet.numericalCols = numericalCols;
        featureSet.indexSize = featureSize;
        featureSet.vecColName = ONE_HOT_VECTOR_COL;
        featureSet.hasDiscrete = true;

        int numCrossFeatures = featureSet.crossFeatureSet.size();
        List<Row> rows = new ArrayList<>(numCrossFeatures + 1);
        rows.add(Row.of(new Object[] {0L, featureSet.toString(), null}));
        for (int i = 0; i < numCrossFeatures; i++) {
            rows.add(Row.of(new Object[] {
                (long) (i + 1), JsonConverter.toJson(featureSet.crossFeatureSet.get(i)), featureSet.scores.get(i)}));
        }
        return rows;
    }

    /** Intentionally empty: this operator adds nothing in this side-output hook. */
    @Override
    void buildSideOutput(OneHotTrainLocalOp oneHotTrainLocalOp, List<Row> list, List<String> list2) {
    }
}
