package com.alibaba.alink.operator.common.feature.AutoCross;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.evaluation.BinaryMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelMapper;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.optim.LocalOptimizer;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.local.classification.BaseLinearModelTrainLocalOp;
import com.alibaba.alink.params.classification.LinearModelMapperParams;
import com.alibaba.alink.params.shared.HasNumThreads;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/AutoCross/FeatureEvaluator.class */
/**
 * Evaluates candidate crossed (combined) categorical features.
 *
 * <p>For a list of candidate crosses, {@link #score(List, int)} appends one one-hot slot per cross to every
 * sample vector, trains a local logistic-regression model on a random split and reports the mean AUC over
 * {@code kCross} splits.
 */
public class FeatureEvaluator {
    private static final boolean HAS_INTERCEPT = true;

    private final LinearModelType linearModelType;
    /** Samples as (weight-or-id, label, features); only f1 and f2 are interpreted here, f0 is passed through. */
    private final List<Tuple3<Double, Double, Vector>> data;
    /** Currently unused by training: forwarded to {@link #train(List, DataProfile, boolean, double[])}, which ignores it. */
    private final double[] fixedCoefs;
    /** Number of distinct values (one-hot width) of each categorical feature, in feature order. */
    private final int[] featureSize;
    /** Probability that a sample (after the first two) is put into the training side of a split. */
    private final double fraction;
    /** Currently unused by training: see {@link #fixedCoefs}. */
    private final boolean toFixCoef;
    /** Number of random train/test splits averaged by {@link #score(List, int)}. */
    private final int kCross;

    public FeatureEvaluator(LinearModelType linearModelType, List<Tuple3<Double, Double, Vector>> data,
                            int[] featureSize, double[] fixedCoefs, double fraction, boolean toFixCoef,
                            int kCross) {
        this.linearModelType = linearModelType;
        this.data = data;
        this.featureSize = featureSize;
        this.fixedCoefs = fixedCoefs;
        this.fraction = fraction;
        this.toFixCoef = toFixCoef;
        this.kCross = kCross;
    }

    /**
     * Scores a set of candidate crosses.
     *
     * @param crosses       candidate crosses; each entry lists the categorical feature ids that are combined
     * @param numericalSize number of leading vector entries whose values are kept as-is (not one-hot slots)
     * @return (mean AUC over {@code kCross} splits, coefficients of the model trained on the last split)
     * @throws IllegalStateException if {@code kCross} is not positive
     */
    public Tuple2<Double, double[]> score(List<int[]> crosses, int numericalSize) {
        if (kCross <= 0) {
            // Without at least one split there is no model to return (previously an NPE).
            throw new IllegalStateException("kCross must be positive, but got " + kCross);
        }
        DataProfile profile = new DataProfile(linearModelType, HAS_INTERCEPT);
        List<Tuple3<Double, Double, Vector>> expanded = expandFeatures(data, crosses, featureSize, numericalSize);
        System.out.println(JsonConverter.toJson(crosses) + ",vector size: " + expanded.get(0).f2.size());

        LinearModelData model = null;
        double aucSum = 0.0;
        for (int fold = 0; fold < kCross; fold++) {
            // The fold index doubles as the split seed, so results are reproducible.
            Tuple2<List<Tuple3<Double, Double, Vector>>, List<Tuple3<Double, Double, Vector>>> trainTest =
                split(expanded, fraction, fold);
            model = train(trainTest.f0, profile, toFixCoef, fixedCoefs);
            aucSum += evaluate(model, trainTest.f1);
        }
        return Tuple2.of(aucSum / kCross, model.coefVector.getData());
    }

    /**
     * Trains a model; the coefficient-fixing arguments are accepted for API compatibility but are currently
     * ignored — this simply delegates to {@link #train(List, DataProfile)}.
     */
    public static LinearModelData train(List<Tuple3<Double, Double, Vector>> samples, DataProfile profile,
                                        boolean toFixCoef, double[] fixedCoefs) {
        return train(samples, profile);
    }

    /**
     * Trains a logistic-regression model locally (single thread, no standardization).
     *
     * <p>The model type is always {@code LR}, regardless of the profile. Labels equal to 0.0 are mapped to
     * +1 and every other label to -1 before optimization.
     */
    public static LinearModelData train(List<Tuple3<Double, Double, Vector>> samples, DataProfile profile) {
        LinearModelType modelType = LinearModelType.LR;
        OptimObjFunc objFunc = OptimObjFunc.getObjFunction(modelType, new Params());
        boolean remapLabels = modelType.equals(LinearModelType.LR);

        List<Tuple3<Double, Double, Vector>> trainData = samples.stream()
            .map(t -> {
                Vector features = profile.hasIntercept ? t.f2.prefix(1.0) : t.f2;
                double label = t.f1;
                if (remapLabels) {
                    // The optimizer expects +1/-1 labels; label 0.0 becomes the +1 side.
                    label = label == 0.0 ? 1.0 : -1.0;
                }
                return Tuple3.of(t.f0, label, features);
            })
            .collect(Collectors.toList());

        Params optimParams = new Params()
            .set(HasNumThreads.NUM_THREADS, 1)
            .set(LinearTrainParams.WITH_INTERCEPT, profile.hasIntercept)
            .set(LinearTrainParams.STANDARDIZATION, false);

        // Small non-zero starting point for every coefficient (intercept included when prefixed).
        double[] initCoef = new double[trainData.get(0).f2.size()];
        Arrays.fill(initCoef, 1.0E-4);
        Tuple2<DenseVector, double[]> optimized =
            LocalOptimizer.optimize(objFunc, trainData, new DenseVector(initCoef), optimParams);

        Double[] labelValues = new Double[profile.numDistinctLabels];
        for (int i = 0; i < labelValues.length; i++) {
            labelValues[i] = (double) i;
        }
        Params modelMeta = new Params();
        modelMeta.set(ModelParamName.MODEL_NAME, OptimVariable.model);
        modelMeta.set(ModelParamName.LINEAR_MODEL_TYPE, modelType);
        modelMeta.set(ModelParamName.LABEL_VALUES, labelValues);
        modelMeta.set(ModelParamName.HAS_INTERCEPT_ITEM, profile.hasIntercept);
        modelMeta.set(ModelParamName.VECTOR_COL_NAME, "features");
        modelMeta.set(LinearTrainParams.LABEL_COL, (String) null);
        modelMeta.set(ModelParamName.FEATURE_TYPES, (String[]) null);
        return BaseLinearModelTrainLocalOp.buildLinearModelData(
            modelMeta, null, Types.DOUBLE, null, profile.hasIntercept, false, optimized);
    }

    /**
     * Computes the AUC of {@code model} on {@code samples}, treating label "1.0" as the positive class.
     *
     * @throws UnsupportedOperationException if the model is not a logistic-regression model
     * @throws RuntimeException              wrapping any prediction failure
     */
    public static double evaluate(LinearModelData model, List<Tuple3<Double, Double, Vector>> samples) {
        // Fail fast instead of predicting every sample first and then rejecting the model.
        if (!model.linearModelType.equals(LinearModelType.LR)) {
            throw new UnsupportedOperationException("Not yet supported model type: " + model.linearModelType);
        }
        Params mapperParams = new Params()
            .set(LinearModelMapperParams.VECTOR_COL, "features")
            .set(LinearModelMapperParams.PREDICTION_COL, "prediction_result")
            .set(LinearModelMapperParams.PREDICTION_DETAIL_COL, "prediction_detail")
            .set(LinearModelMapperParams.RESERVED_COLS, new String[] {"label"});
        LinearModelMapper mapper = new LinearModelMapper(
            new LinearModelDataConverter(Types.DOUBLE).getModelSchema(),
            new TableSchema(new String[] {"features", "label"},
                new TypeInformation[] {Types.STRING, Types.DOUBLE}),
            mapperParams);
        mapper.loadModel(model);

        // Output row layout: reserved "label", prediction, prediction detail -> keep (label, detail).
        List<Row> labelAndDetail = samples.stream()
            .map(t -> {
                try {
                    Row predicted = mapper.map(Row.of(t.f2, t.f1));
                    return Row.of(predicted.getField(0), predicted.getField(2));
                } catch (Exception e) {
                    throw new RuntimeException("Fail to predict.", e);
                }
            })
            .collect(Collectors.toList());
        BinaryMetricsSummary summary = (BinaryMetricsSummary) EvaluationUtil.getDetailStatistics(
            labelAndDetail, "1.0", true, Types.DOUBLE);
        return summary.toMetrics().getAuc();
    }

    /**
     * Randomly splits samples into (train, test).
     *
     * <p>The first sample always goes to train and the second to test so neither side is empty; every other
     * sample goes to train with probability {@code fraction}.
     *
     * @throws RuntimeException if fewer than two samples are given
     */
    public static Tuple2<List<Tuple3<Double, Double, Vector>>, List<Tuple3<Double, Double, Vector>>> split(
        List<Tuple3<Double, Double, Vector>> samples, double fraction, int seed) {
        if (samples.size() < 2) {
            throw new RuntimeException("Data size is too small!");
        }
        List<Tuple3<Double, Double, Vector>> trainSide = new ArrayList<>(samples.size() / 2 + 1);
        List<Tuple3<Double, Double, Vector>> testSide = new ArrayList<>(samples.size() / 2 + 1);
        trainSide.add(samples.get(0));
        testSide.add(samples.get(1));
        Random random = new Random(seed);
        for (int idx = 2; idx < samples.size(); idx++) {
            if (random.nextDouble() <= fraction) {
                trainSide.add(samples.get(idx));
            } else {
                testSide.add(samples.get(idx));
            }
        }
        return Tuple2.of(trainSide, testSide);
    }

    /**
     * Appends one one-hot entry per candidate cross to every sample vector.
     *
     * <p>Assumes (not checked) that every sample vector is a {@link SparseVector} whose first
     * {@code numericalSize + featureSize.length} entries are the numerical values followed by exactly one
     * active slot per categorical feature, in feature order, with categorical feature {@code f} occupying
     * indices {@code numericalSize + offset(f) .. numericalSize + offset(f) + featureSize[f] - 1}.
     * The value of a cross is the mixed-radix combination of its features' local values, and each cross gets
     * its own block of {@code prod(featureSize[f])} slots after the original vector size. Numerical values are
     * copied; all categorical and cross entries are written as 1.0.
     *
     * @param data          samples to expand (not modified)
     * @param crosses       candidate crosses; each lists categorical feature ids
     * @param featureSize   one-hot width of each categorical feature
     * @param numericalSize number of leading numerical entries per vector
     * @return the expanded samples (empty if {@code data} is empty)
     * @throws ArithmeticException if a cross cardinality or the expanded vector size overflows {@code int}
     */
    public static List<Tuple3<Double, Double, Vector>> expandFeatures(List<Tuple3<Double, Double, Vector>> data,
                                                                      List<int[]> crosses, int[] featureSize,
                                                                      int numericalSize) {
        if (data.isEmpty()) {
            return new ArrayList<>();
        }
        int numCategorical = featureSize.length;
        int numCrosses = crosses.size();
        int numSamples = data.size();
        int originalSize = data.get(0).f2.size();

        // Start of each categorical feature's one-hot block, relative to the first categorical slot.
        int[] categoricalOffsets = new int[numCategorical];
        for (int f = 1; f < numCategorical; f++) {
            categoricalOffsets[f] = categoricalOffsets[f - 1] + featureSize[f - 1];
        }

        // Mixed-radix weights of each cross and its total number of distinct values.
        int[][] radixWeights = new int[numCrosses][];
        int[] crossCardinality = new int[numCrosses];
        for (int c = 0; c < numCrosses; c++) {
            int[] cross = crosses.get(c);
            int[] weights = new int[cross.length];
            int weight = 1;
            for (int k = 0; k < cross.length; k++) {
                weights[k] = weight;
                weight = Math.multiplyExact(weight, featureSize[cross[k]]);
            }
            radixWeights[c] = weights;
            crossCardinality[c] = weight;
        }

        // Each cross occupies its own block of slots placed after the original vector.
        int[] crossOffsets = new int[numCrosses];
        int expandedSize = originalSize;
        for (int c = 0; c < numCrosses; c++) {
            crossOffsets[c] = expandedSize;
            expandedSize = Math.addExact(expandedSize, crossCardinality[c]);
        }

        int keptCount = numericalSize + numCategorical;
        int nnz = keptCount + numCrosses;
        List<Tuple3<Double, Double, Vector>> expanded = new ArrayList<>(numSamples);
        for (Tuple3<Double, Double, Vector> sample : data) {
            SparseVector original = (SparseVector) sample.f2;
            int[] sourceIndices = original.getIndices();

            int[] indices = new int[nnz];
            System.arraycopy(sourceIndices, 0, indices, 0, keptCount);
            for (int c = 0; c < numCrosses; c++) {
                indices[keptCount + c] = crossOffsets[c]
                    + crossValue(radixWeights[c], crosses.get(c), sourceIndices, numericalSize, categoricalOffsets);
            }

            double[] values = new double[nnz];
            Arrays.fill(values, 1.0);
            System.arraycopy(original.getValues(), 0, values, 0, numericalSize);

            Vector crossed = new SparseVector(expandedSize, indices, values);
            expanded.add(Tuple3.of(sample.f0, sample.f1, crossed));
        }
        return expanded;
    }

    /** Combines the local (within-feature) values of the crossed features into one mixed-radix value. */
    private static int crossValue(int[] weights, int[] cross, int[] indices, int numericalSize,
                                  int[] categoricalOffsets) {
        int value = 0;
        for (int k = 0; k < weights.length; k++) {
            int feature = cross[k];
            // Global index -> position inside this feature's one-hot block.
            int localValue = indices[numericalSize + feature] - numericalSize - categoricalOffsets[feature];
            value += weights[k] * localValue;
        }
        return value;
    }
}
