package com.alibaba.alink.operator.common.finance.stepwiseSelector;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.common.viz.AlinkViz;
import com.alibaba.alink.common.viz.VizDataWriterInterface;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.finance.ScorecardTrainBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.DataSetUtil;
import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper;
import com.alibaba.alink.operator.common.feature.SelectorModelData;
import com.alibaba.alink.operator.common.feature.SelectorModelDataConverter;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.linear.LinearRegressionSummary;
import com.alibaba.alink.operator.common.linear.LocalLinearModel;
import com.alibaba.alink.operator.common.linear.LogistRegressionSummary;
import com.alibaba.alink.operator.common.linear.ModelSummary;
import com.alibaba.alink.operator.common.linear.ModelSummaryHelper;
import com.alibaba.alink.operator.common.optim.FeatureConstraint;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.finance.BaseStepwiseSelectorParams;
import com.alibaba.alink.params.finance.ConstrainedLogisticRegressionTrainParams;
import com.alibaba.alink.params.finance.HasConstrainedLinearModelType;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp.class */
public class BaseStepWiseSelectorBatchOp extends BatchOperator<BaseStepWiseSelectorBatchOp> implements BaseStepwiseSelectorParams<BaseStepWiseSelectorBatchOp>, AlinkViz<BaseStepWiseSelectorBatchOp> {
    private static final long serialVersionUID = -1353820179732001005L;
    // Internal column names. NOTE(review): neither constant is referenced in the visible part of this file — confirm usage.
    private static final String INNER_VECTOR_COL = "vec";
    private static final String INNER_LABLE_COL = "label";
    // Training input (linkFrom input 0); standardLabel() may replace it with a label-transformed operator.
    private BatchOperator in;
    // When non-null, broadcast to BuildLinearModel under BaseFmTrainBatchOp.LABEL_VALUES
    // (presumably assigned in standardLabel() — confirm).
    private DataSet<Object> labels;
    // True when a constraint input (linkFrom input 1) was supplied; set by standardConstraint().
    private boolean hasConstraint;
    // Broadcast to StepWiseMapPartition as "constraint"; a one-row placeholder when no constraint input exists.
    private DataSet<Row> constraintDataSet;
    // Passed to StepWiseMapPartition to decide whether the "vectorSizes" broadcast is read
    // (presumably assigned during transformToVector() — confirm).
    private boolean hasVectorSizes;
    // Per-column vector sizes broadcast as "vectorSizes"; null right after construction.
    private DataSet<int[]> vectorSizes;
    // Column names/indices used to build the selector pipeline in linkFrom()
    // (presumably assigned during transformToVector() — confirm).
    private String selectColNew;
    private String labelColNew;
    private int selectedColIdxNew;
    private int labelIdxNew;
    // selectedCol becomes SelectorModelData.vectorColName; selectedCols are the candidate column names
    // that selected indices are mapped back to.
    private String selectedCol;
    private String[] selectedCols;
    // Model type captured in linkFrom(); also used by getStepWiseSummary().
    private HasConstrainedLinearModelType.LinearModelType linearModelType;
    // Value of ScorecardTrainBatchOp.IN_SCORECARD read in linkFrom().
    private boolean inScorcard;

    /* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp$BuildLinearModel.class */
    public static class BuildLinearModel extends RichFlatMapFunction<Row, Row> {
        private HasConstrainedLinearModelType.LinearModelType linearModelType;
        private String[] selectedCols;
        private TypeInformation[] selectedColsType;
        private String labelCol;
        private TypeInformation labelType;
        private Object[] labelValues;
        private String positiveLabel;
        private boolean inScorecard;

        public BuildLinearModel(HasConstrainedLinearModelType.LinearModelType linearModelType, String[] strArr, TypeInformation[] typeInformationArr, String str, TypeInformation typeInformation, String str2, boolean z) {
            this.linearModelType = linearModelType;
            this.selectedCols = strArr;
            this.selectedColsType = typeInformationArr;
            this.labelCol = str;
            this.labelType = typeInformation;
            this.positiveLabel = str2;
            this.inScorecard = z;
        }

        public void open(Configuration configuration) {
            if (HasConstrainedLinearModelType.LinearModelType.LR == this.linearModelType || this.inScorecard) {
                this.labelValues = ModelSummaryHelper.orderLabels(getRuntimeContext().getBroadcastVariable(BaseFmTrainBatchOp.LABEL_VALUES), this.positiveLabel);
            }
        }

        public void flatMap(Row row, Collector<Row> collector) throws Exception {
            int[] iArr;
            DenseVector denseVector;
            if (HasConstrainedLinearModelType.LinearModelType.LR == this.linearModelType) {
                ClassificationSelectorResult classificationSelectorResult = (ClassificationSelectorResult) JsonConverter.fromJson((String) row.getField(0), ClassificationSelectorResult.class);
                iArr = classificationSelectorResult.selectedIndices;
                denseVector = classificationSelectorResult.modelSummary.beta;
            } else {
                RegressionSelectorResult regressionSelectorResult = (RegressionSelectorResult) JsonConverter.fromJson((String) row.getField(0), RegressionSelectorResult.class);
                iArr = regressionSelectorResult.selectedIndices;
                denseVector = regressionSelectorResult.modelSummary.beta;
            }
            String[] sCurSelectedCols = BaseStepWiseSelectorBatchOp.getSCurSelectedCols(this.selectedCols, iArr);
            String[] strArr = new String[sCurSelectedCols.length];
            for (int i = 0; i < sCurSelectedCols.length; i++) {
                strArr[i] = this.selectedColsType[iArr[i]].getTypeClass().getSimpleName();
            }
            new LinearModelDataConverter(this.labelType).save(BaseLinearModelTrainBatchOp.buildLinearModelData(new Params().set((ParamInfo<ParamInfo<String>>) ModelParamName.MODEL_NAME, (ParamInfo<String>) OptimVariable.model).set((ParamInfo<ParamInfo<LinearModelType>>) ModelParamName.LINEAR_MODEL_TYPE, (ParamInfo<LinearModelType>) LinearModelType.valueOf(this.linearModelType.name())).set((ParamInfo<ParamInfo<Object[]>>) ModelParamName.LABEL_VALUES, (ParamInfo<Object[]>) this.labelValues).set((ParamInfo<ParamInfo<Boolean>>) ModelParamName.HAS_INTERCEPT_ITEM, (ParamInfo<Boolean>) true).set((ParamInfo<ParamInfo<String[]>>) ModelParamName.FEATURE_TYPES, (ParamInfo<String[]>) strArr).set((ParamInfo<ParamInfo<String>>) LinearTrainParams.LABEL_COL, (ParamInfo<String>) this.labelCol), sCurSelectedCols, this.labelType, null, true, false, Tuple2.of(denseVector, new double[]{Criteria.INVALID_GAIN})), collector);
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row>) collector);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp$BuildModel.class */
    public static class BuildModel implements FlatMapFunction<Row, Row> {
        private static final long serialVersionUID = -4792429339624354557L;
        private String selectedCol;
        private String[] selectedCols;
        private HasConstrainedLinearModelType.LinearModelType linearModelType;

        BuildModel(String str, String[] strArr, HasConstrainedLinearModelType.LinearModelType linearModelType) {
            this.selectedCol = str;
            this.selectedCols = strArr;
            this.linearModelType = linearModelType;
        }

        public void flatMap(Row row, Collector<Row> collector) throws Exception {
            SelectorModelData selectorModelData = new SelectorModelData();
            if (HasConstrainedLinearModelType.LinearModelType.LR == this.linearModelType) {
                ClassificationSelectorResult classificationSelectorResult = (ClassificationSelectorResult) JsonConverter.fromJson((String) row.getField(0), ClassificationSelectorResult.class);
                selectorModelData.vectorColName = this.selectedCol;
                selectorModelData.selectedIndices = classificationSelectorResult.selectedIndices;
                selectorModelData.vectorColNames = BaseStepWiseSelectorBatchOp.getSCurSelectedCols(this.selectedCols, classificationSelectorResult.selectedIndices);
            } else {
                RegressionSelectorResult regressionSelectorResult = (RegressionSelectorResult) JsonConverter.fromJson((String) row.getField(0), RegressionSelectorResult.class);
                selectorModelData.vectorColName = this.selectedCol;
                selectorModelData.selectedIndices = regressionSelectorResult.selectedIndices;
                selectorModelData.vectorColNames = BaseStepWiseSelectorBatchOp.getSCurSelectedCols(this.selectedCols, regressionSelectorResult.selectedIndices);
            }
            new SelectorModelDataConverter().save(selectorModelData, collector);
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row>) collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp$CalcVectorSize.class */
    public static class CalcVectorSize implements MapPartitionFunction<Row, int[]> {
        private static final long serialVersionUID = -4671985561279422505L;
        private int[] selectedColIndices;

        CalcVectorSize(int[] iArr) {
            this.selectedColIndices = iArr;
        }

        public void mapPartition(Iterable<Row> iterable, Collector<int[]> collector) throws Exception {
            int numberOfValues;
            int[] iArr = new int[this.selectedColIndices.length];
            Arrays.fill(iArr, 0);
            for (Row row : iterable) {
                for (int i = 0; i < iArr.length; i++) {
                    Vector vector = VectorUtil.getVector(row.getField(this.selectedColIndices[i]));
                    if (vector instanceof DenseVector) {
                        numberOfValues = vector.size();
                    } else {
                        SparseVector sparseVector = (SparseVector) vector;
                        numberOfValues = sparseVector.size() == -1 ? sparseVector.numberOfValues() : sparseVector.size();
                    }
                    iArr[i] = Math.max(iArr[i], numberOfValues);
                }
            }
            collector.collect(iArr);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp$ProcViz.class */
    public static class ProcViz implements FlatMapFunction<Row, Row> {
        private static final long serialVersionUID = -2466978695333617548L;
        private HasConstrainedLinearModelType.LinearModelType linearModelType;
        private String[] selectedCols;
        private VizDataWriterInterface writer;

        ProcViz(HasConstrainedLinearModelType.LinearModelType linearModelType, String[] strArr, VizDataWriterInterface vizDataWriterInterface) {
            this.linearModelType = linearModelType;
            this.selectedCols = strArr;
            this.writer = vizDataWriterInterface;
        }

        public void flatMap(Row row, Collector<Row> collector) throws Exception {
            String vizData;
            String str = (String) row.getField(0);
            if (HasConstrainedLinearModelType.LinearModelType.LR == this.linearModelType) {
                ClassificationSelectorResult classificationSelectorResult = (ClassificationSelectorResult) JsonConverter.fromJson(str, ClassificationSelectorResult.class);
                classificationSelectorResult.selectedCols = BaseStepWiseSelectorBatchOp.getSCurSelectedCols(this.selectedCols, classificationSelectorResult.selectedIndices);
                vizData = classificationSelectorResult.toVizData();
            } else {
                RegressionSelectorResult regressionSelectorResult = (RegressionSelectorResult) JsonConverter.fromJson(str, RegressionSelectorResult.class);
                regressionSelectorResult.selectedCols = BaseStepWiseSelectorBatchOp.getSCurSelectedCols(this.selectedCols, regressionSelectorResult.selectedIndices);
                vizData = regressionSelectorResult.toVizData();
            }
            this.writer.writeBatchData(0L, vizData, System.currentTimeMillis());
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row>) collector);
        }
    }

    /**
     * Runs step-wise feature selection over all rows of a partition and emits a single row whose field 0 is
     * the JSON-serialized result: a {@link RegressionSelectorResult} for linear regression, otherwise a
     * {@link ClassificationSelectorResult}.
     *
     * <p>Broadcast inputs:
     * <ul>
     *   <li>{@code "summarizer"}: a {@link BaseVectorSummarizer};</li>
     *   <li>{@code "vectorSizes"}: per-column sizes, read only when {@code hasVectorSizes};</li>
     *   <li>{@code "constraint"}: a row whose field 0 is the constraint, read only when {@code hasConstraint}.</li>
     * </ul>
     *
     * <p>Each round, every remaining candidate column is tentatively added and a model is trained. The
     * candidate with the highest forward score enters. Once at least one column (including forced ones)
     * is in the model, the candidate must also satisfy the {@code alphaEntry} threshold. If more than one
     * column is then selected, the column with the largest backward value is removed when that value is at
     * least {@code alphaStay} and the column is not force-selected. Selection stops when no candidate
     * qualifies, or when the column that just entered is removed again.
     */
    public static class StepWiseMapPartition extends RichMapPartitionFunction<Tuple2<Vector, Row>, Row> {
        private static final long serialVersionUID = -3204445094879081660L;
        private int featureSize;
        private boolean hasVectorSizes;
        private int[] vectorSizes;
        private BaseVectorSummarizer summary;
        private List<Tuple3<Double, Double, Vector>> trainData;
        private StepWiseType stepwiseType;
        private double alphaEntry;
        private double alphaStay;
        private int[] forceSelectedIndices;
        private LinearModelType linearModelType;
        private String optimMethod;
        private double l1;
        private double l2;
        private boolean hasConstraint;
        private FeatureConstraint constraints;

        StepWiseMapPartition(int[] forceSelectedIndices, double alphaEntry, double alphaStay,
                             HasConstrainedLinearModelType.LinearModelType linearModelType, String optimMethod,
                             StepWiseType stepWiseType, double l1, double l2, boolean hasVectorSizes,
                             boolean hasConstraint) {
            // Treat a missing array as "no forced columns" so later loops and length checks are safe.
            this.forceSelectedIndices = forceSelectedIndices == null ? new int[0] : forceSelectedIndices;
            this.alphaEntry = alphaEntry;
            this.alphaStay = alphaStay;
            this.linearModelType = LinearModelType.valueOf(linearModelType.name());
            this.optimMethod = optimMethod.toUpperCase().trim();
            this.stepwiseType = stepWiseType;
            this.l1 = l1;
            this.l2 = l2;
            this.hasVectorSizes = hasVectorSizes;
            this.hasConstraint = hasConstraint;
        }

        @Override
        public void open(Configuration configuration) {
            this.summary = (BaseVectorSummarizer) getRuntimeContext().getBroadcastVariable("summarizer").get(0);
            if (this.hasVectorSizes) {
                // Turn per-column sizes into cumulative offsets: vectorSizes[0] = 0,
                // vectorSizes[c + 1] = vectorSizes[c] + size of column c (presumably consumed by getOne()).
                int[] sizes = (int[]) getRuntimeContext().getBroadcastVariable("vectorSizes").get(0);
                this.vectorSizes = new int[sizes.length + 1];
                this.vectorSizes[0] = 0;
                for (int i = 0; i < sizes.length; i++) {
                    this.vectorSizes[i + 1] = this.vectorSizes[i] + sizes[i];
                }
            }
            String constraintJson = null;
            if (this.hasConstraint) {
                Object field = ((Row) getRuntimeContext().getBroadcastVariable("constraint").get(0)).getField(0);
                constraintJson = field instanceof FeatureConstraint ? field.toString() : (String) field;
            }
            this.constraints = FeatureConstraint.fromJson(constraintJson);
        }

        @Override
        public void mapPartition(Iterable<Tuple2<Vector, Row>> iterable, Collector<Row> collector) throws Exception {
            this.trainData = transformData(iterable);

            // Column indices currently in the model (forced columns first).
            ArrayList<Integer> enteredCols = new ArrayList<>();
            // Feature indices of enteredCols, expanded through getOne(); this is what train() consumes.
            ArrayList<Integer> enteredFeatures = new ArrayList<>();
            // Column indices that may still enter.
            ArrayList<Integer> candidates = new ArrayList<>();
            for (int forced : this.forceSelectedIndices) {
                enteredCols.add(forced);
                enteredFeatures.addAll(BaseStepWiseSelectorBatchOp.getOne(forced, this.vectorSizes));
            }
            for (int col = 0; col < this.featureSize; col++) {
                if (!enteredCols.contains(col)) {
                    candidates.add(col);
                }
            }

            ModelSummary currentModel = null;
            List<SelectorStep> steps = new ArrayList<>();
            if (this.forceSelectedIndices.length != 0) {
                currentModel = train(BaseStepWiseSelectorBatchOp.getIndicesFromList(enteredFeatures));
            }
            // Reference model for Mallows' Cp (trained with null indices); only needed for linear regression.
            ModelSummary referenceModel = isLinearRegression() ? train(null) : null;

            while (!candidates.isEmpty()) {
                // Forward step: find the best candidate to enter.
                int enteringCol = -1;
                ModelSummary enteringModel = null;
                double bestScore = Double.NEGATIVE_INFINITY;
                for (Integer candidate : candidates) {
                    ArrayList<Integer> trialFeatures = new ArrayList<>(enteredFeatures);
                    trialFeatures.addAll(BaseStepWiseSelectorBatchOp.getOne(candidate, this.vectorSizes));
                    ModelSummary trialModel = train(BaseStepWiseSelectorBatchOp.getIndicesFromList(trialFeatures));
                    Tuple2<?, ?> forward =
                        BaseStepWiseSelectorBatchOp.getForwardValue(trialModel, currentModel, this.stepwiseType);
                    double score = (Double) forward.f0;
                    // The entry threshold only applies once something is already in the model.
                    if (score > bestScore && (enteredCols.isEmpty() || (Double) forward.f1 <= this.alphaEntry)) {
                        bestScore = score;
                        enteringModel = trialModel;
                        enteringCol = candidate;
                    }
                }
                if (enteringCol < 0) {
                    break;
                }
                enteredFeatures.addAll(BaseStepWiseSelectorBatchOp.getOne(enteringCol, this.vectorSizes));
                enteredCols.add(enteringCol);
                candidates.remove(Integer.valueOf(enteringCol));

                // Backward step: possibly drop the weakest non-forced column.
                Integer removedCol = null;
                if (enteredCols.size() > 1) {
                    double[] backward = BaseStepWiseSelectorBatchOp.getBackwardValues(
                        enteringModel, this.summary, this.stepwiseType, this.vectorSizes, enteredCols);
                    int worst = BaseStepWiseSelectorBatchOp.argmax(backward);
                    if (backward[worst] >= this.alphaStay
                        && !BaseStepWiseSelectorBatchOp.isIdxExist(this.forceSelectedIndices, enteredCols.get(worst))) {
                        removedCol = enteredCols.get(worst);
                        enteredCols.remove(removedCol);
                        enteredFeatures.removeAll(BaseStepWiseSelectorBatchOp.getOne(removedCol, this.vectorSizes));
                        candidates.add(removedCol);
                    }
                }
                // The column that just entered was removed right away: selection has converged.
                if (removedCol != null && removedCol == enteringCol) {
                    break;
                }

                currentModel = calMallowCp(enteringModel, referenceModel);
                steps.add(currentModel.toSelectStep(enteringCol));
                if (removedCol != null) {
                    // Forget the step in which the removed column had entered.
                    for (int s = 0; s < steps.size(); s++) {
                        if (removedCol == Integer.valueOf(steps.get(s).enterCol).intValue()) {
                            steps.remove(s);
                            break;
                        }
                    }
                    // Refit on the remaining *feature* indices, consistent with every other train() call
                    // (passing column indices here trains on the wrong features when vectorSizes is set).
                    currentModel = train(BaseStepWiseSelectorBatchOp.getIndicesFromList(enteredFeatures));
                }
            }

            Row result = new Row(1);
            result.setField(0, bestModelResult(steps, currentModel));
            collector.collect(result);
        }

        /**
         * For linear regression, stores Mallows' Cp on {@code model}, computed against {@code reference}.
         * Other model types are returned unchanged.
         */
        private ModelSummary calMallowCp(ModelSummary model, ModelSummary reference) {
            if (isLinearRegression()) {
                LinearRegressionSummary lrs = (LinearRegressionSummary) model;
                lrs.mallowCp = ((((lrs.count - this.featureSize) - 1) * (lrs.sse / ((LinearRegressionSummary) reference).sse))
                    - lrs.count) + (2 * (lrs.lowerConfidence.length + 1));
            }
            return model;
        }

        private boolean isLinearRegression() {
            return this.linearModelType == LinearModelType.LinearReg;
        }

        /**
         * Serializes the selection outcome to JSON. {@code selectedIndices} holds the forced columns first,
         * followed by the entered column of each recorded step.
         */
        private String bestModelResult(List<SelectorStep> list, ModelSummary modelSummary) {
            int forcedCount = this.forceSelectedIndices.length;
            if (isLinearRegression()) {
                RegressionSelectorResult result = new RegressionSelectorResult();
                result.entryVars = new RegressionSelectorStep[list.size()];
                result.selectedIndices = new int[list.size() + forcedCount];
                System.arraycopy(this.forceSelectedIndices, 0, result.selectedIndices, 0, forcedCount);
                for (int i = 0; i < list.size(); i++) {
                    result.entryVars[i] = (RegressionSelectorStep) list.get(i);
                    result.selectedIndices[i + forcedCount] = Integer.valueOf(result.entryVars[i].enterCol).intValue();
                }
                result.modelSummary = (LinearRegressionSummary) modelSummary;
                return JsonConverter.toJson(result);
            }
            ClassificationSelectorResult result = new ClassificationSelectorResult();
            result.entryVars = new ClassificationSelectorStep[list.size()];
            result.selectedIndices = new int[list.size() + forcedCount];
            System.arraycopy(this.forceSelectedIndices, 0, result.selectedIndices, 0, forcedCount);
            for (int i = 0; i < list.size(); i++) {
                result.entryVars[i] = (ClassificationSelectorStep) list.get(i);
                result.selectedIndices[i + forcedCount] = Integer.valueOf(result.entryVars[i].enterCol).intValue();
            }
            result.modelSummary = (LogistRegressionSummary) modelSummary;
            return JsonConverter.toJson(result);
        }

        /**
         * Sets {@code featureSize} (number of candidate columns) and materializes the partition as
         * (1.0, label, features) tuples — presumably (weight, label, features) as LocalLinearModel expects.
         */
        private List<Tuple3<Double, Double, Vector>> transformData(Iterable<Tuple2<Vector, Row>> iterable) {
            if (this.hasVectorSizes) {
                this.featureSize = this.vectorSizes.length - 1;
            } else {
                // The summarizer sees feature vectors with the label prepended, hence the -1.
                this.featureSize = this.summary.toSummary().vectorSize() - 1;
            }
            List<Tuple3<Double, Double, Vector>> data = new ArrayList<>();
            for (Tuple2<Vector, Row> sample : iterable) {
                if (this.vectorSizes == null && sample.f0 instanceof SparseVector) {
                    ((SparseVector) sample.f0).setSize(this.featureSize);
                }
                data.add(Tuple3.of(1.0d, ((Number) sample.f1.getField(0)).doubleValue(), sample.f0));
            }
            return data;
        }

        /** Trains on the given feature indices; {@code null} passes the full constraint set. */
        private ModelSummary train(int[] featureIndices) {
            return LocalLinearModel.trainWithSummary(this.trainData, featureIndices, this.linearModelType, this.optimMethod,
                true, false,
                featureIndices == null ? this.constraints.toString() : this.constraints.extractConstraint(featureIndices),
                this.l1, this.l2, this.summary);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp$ToClassificationSelectorResult.class */
    public static class ToClassificationSelectorResult implements MapFunction<Row, SelectorResult> {
        private static final long serialVersionUID = 382487577293357907L;
        private String[] selectedCols;
        private HasConstrainedLinearModelType.LinearModelType linearModelType;

        ToClassificationSelectorResult(HasConstrainedLinearModelType.LinearModelType linearModelType, String[] strArr) {
            this.selectedCols = strArr;
            this.linearModelType = linearModelType;
        }

        public SelectorResult map(Row row) throws Exception {
            SelectorResult selectorResult = HasConstrainedLinearModelType.LinearModelType.LR == this.linearModelType ? (SelectorResult) JsonConverter.fromJson((String) row.getField(0), ClassificationSelectorResult.class) : (SelectorResult) JsonConverter.fromJson((String) row.getField(0), RegressionSelectorResult.class);
            selectorResult.selectedCols = BaseStepWiseSelectorBatchOp.getSCurSelectedCols(this.selectedCols, selectorResult.selectedIndices);
            return selectorResult;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp$ToVectorWithReservedCols.class */
    public static class ToVectorWithReservedCols implements MapFunction<Row, Vector> {
        private static final long serialVersionUID = 5163307315870607698L;
        private int vectorColIdx;
        private int labelColIdx;

        public ToVectorWithReservedCols(int i, int i2) {
            this.vectorColIdx = i;
            this.labelColIdx = i2;
        }

        public Vector map(Row row) throws Exception {
            Vector vector = VectorUtil.getVector(row.getField(this.vectorColIdx));
            if (vector == null) {
                throw new RuntimeException("vector is null, please check your input data.");
            }
            return vector.prefix(((Number) row.getField(this.labelColIdx)).doubleValue());
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp$VectorAssembler.class */
    public static class VectorAssembler extends RichMapFunction<Row, Row> {
        private static final long serialVersionUID = 2474917145545199423L;
        private int[] vectorSizes;
        private int[] selectedColIndices;
        private int labelColIdx;

        public VectorAssembler(int[] iArr, int i) {
            this.selectedColIndices = iArr;
            this.labelColIdx = i;
        }

        public void open(Configuration configuration) {
            this.vectorSizes = (int[]) getRuntimeContext().getBroadcastVariable("vectorSizes").get(0);
        }

        public Row map(Row row) throws Exception {
            int length = this.selectedColIndices.length;
            Object[] objArr = new Object[length];
            for (int i = 0; i < length; i++) {
                Vector vector = VectorUtil.getVector(row.getField(this.selectedColIndices[i]));
                if (vector instanceof SparseVector) {
                    ((SparseVector) vector).setSize(this.vectorSizes[i]);
                }
                objArr[i] = vector;
            }
            Row row2 = new Row(2);
            row2.setField(0, VectorAssemblerMapper.assembler(objArr));
            row2.setField(1, row.getField(this.labelColIdx));
            return row2;
        }
    }

    /**
     * Creates the operator from its parameters. No vector-size data set is known yet, so
     * {@code vectorSizes} starts out as null.
     *
     * @param params operator parameters
     */
    public BaseStepWiseSelectorBatchOp(Params params) {
        super(params);
        vectorSizes = null;
    }

    /**
     * Links the training data (input 0) and an optional constraint table (input 1), then builds the
     * step-wise selection pipeline.
     *
     * <p>The result is written to three places:
     * <ul>
     *   <li>main output: selector model data built by {@link BuildModel};</li>
     *   <li>side output 0: linear model data for the selected features, built by {@link BuildLinearModel};</li>
     *   <li>side output 1: the raw JSON selection result in a single STRING column {@code "result"}.</li>
     * </ul>
     *
     * <p>NOTE(review): {@code this.vectorSizes} is null after construction and is broadcast here
     * unconditionally; it is presumably assigned by {@code transformToVector(...)} — confirm.
     *
     * @param batchOperatorArr one or two inputs
     * @return this operator
     * @throws InvalidParameterException if the number of inputs is neither one nor two
     */
    @Override
    public BaseStepWiseSelectorBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        int inputCount = batchOperatorArr.length;
        if (inputCount != 1 && inputCount != 2) {
            throw new InvalidParameterException("input size must be one or two.");
        }
        this.linearModelType = getLinearModelType();
        this.inScorcard = (Boolean) getParams().get(ScorecardTrainBatchOp.IN_SCORECARD);
        this.in = batchOperatorArr[0];
        BatchOperator<?> constraintInput = inputCount == 2 ? batchOperatorArr[1] : null;

        String labelCol = getLabelCol();
        int labelIdx = TableUtil.findColIndexWithAssertAndHint(this.in.getColNames(), labelCol);
        TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(this.in.getSchema(), labelCol);

        int[] forceSelected = getParams().contains(BaseStepwiseSelectorParams.FORCE_SELECTED_COLS)
            ? getForceSelectedCols()
            : new int[0];

        boolean needsLabelValues = HasConstrainedLinearModelType.LinearModelType.LR == this.linearModelType || this.inScorcard;
        String positiveLabel = null;
        if (needsLabelValues && getParams().contains(ConstrainedLogisticRegressionTrainParams.POS_LABEL_VAL_STR)) {
            positiveLabel = (String) getParams().get(ConstrainedLogisticRegressionTrainParams.POS_LABEL_VAL_STR);
        }

        // Order matters: these helpers update this.in and the fields used to build the pipeline below.
        standardConstraint(constraintInput);
        standardLabel(positiveLabel);
        transformToVector(labelIdx, labelType, this.linearModelType);

        DataSet<Row> selectionResult = StatisticsHelper
            .transformToVector(this.in, null, this.selectColNew, new String[] {this.labelColNew})
            .mapPartition(new StepWiseMapPartition(forceSelected, getAlphaEntry(), getAlphaStay(),
                getLinearModelType(), getOptimMethod(), getStepWiseType(), getL1(), getL2(),
                this.hasVectorSizes, this.hasConstraint))
            .withBroadcastSet(StatisticsHelper.summarizer(
                this.in.getDataSet().map(new ToVectorWithReservedCols(this.selectedColIdxNew, this.labelIdxNew)), true),
                "summarizer")
            .withBroadcastSet(this.vectorSizes, "vectorSizes")
            .withBroadcastSet(this.constraintDataSet, "constraint")
            .setParallelism(1);

        if (getWithViz()) {
            writeVizData(selectionResult, getLinearModelType(), this.selectedCols, getVizDataWriter());
        }

        setOutput(
            selectionResult.flatMap(new BuildModel(this.selectedCol, this.selectedCols, getLinearModelType())).setParallelism(1),
            new SelectorModelDataConverter().getModelSchema());

        BuildLinearModel linearModelBuilder = new BuildLinearModel(getLinearModelType(), this.selectedCols,
            TableUtil.findColTypes(batchOperatorArr[0].getSchema(), this.selectedCols), labelCol, labelType,
            positiveLabel, this.inScorcard);
        DataSet<Row> linearModel = this.labels == null
            ? selectionResult.flatMap(linearModelBuilder)
            : selectionResult.flatMap(linearModelBuilder).withBroadcastSet(this.labels, BaseFmTrainBatchOp.LABEL_VALUES);

        Table linearModelTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), linearModel,
            new LinearModelDataConverter(labelType).getModelSchema());
        Table resultTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), selectionResult,
            new String[] {"result"}, new TypeInformation<?>[] {Types.STRING});
        setSideOutputTables(new Table[] {linearModelTable, resultTable});
        return this;
    }

    /**
     * Converts side output 1 (the single-column "result" table produced by the step-wise search)
     * into {@link SelectorResult} records.
     *
     * @return one {@code SelectorResult} per row of side output 1, built by
     *     {@code ToClassificationSelectorResult} using the model type and the selected columns
     */
    public DataSet<SelectorResult> getStepWiseSummary() {
        DataSet<Row> summaryRows = getSideOutput(1).getDataSet();
        return summaryRows.map(new ToClassificationSelectorResult(this.linearModelType, getSelectedCols()));
    }

    /**
     * Chooses the data set that will be broadcast as "constraint" and records whether real
     * constraints were supplied.
     *
     * @param batchOperator optional constraint input; when {@code null} a one-row placeholder
     *     ({@code new Row(0)}) is used instead and {@code hasConstraint} ends up {@code false}
     */
    private void standardConstraint(BatchOperator batchOperator) {
        this.hasConstraint = true;
        if (batchOperator == null) {
            // linkFrom broadcasts constraintDataSet unconditionally, so a placeholder is required.
            this.constraintDataSet = MLEnvironmentFactory
                .get(this.in.getMLEnvironmentId())
                .getExecutionEnvironment()
                .fromElements(new Row[] {new Row(0)});
            this.hasConstraint = false;
            return;
        }
        this.constraintDataSet = batchOperator.getDataSet();
    }

    /**
     * For logistic regression (or scorecard mode) rewrites the input's label column via
     * {@code ModelSummaryHelper.transformLrLabel} and keeps the label values it returns.
     * For other model types the input is left untouched.
     *
     * @param str positive label value (read from POS_LABEL_VAL_STR by the caller); may be null
     */
    private void standardLabel(String str) {
        // Read eagerly, exactly like before, even when no relabeling happens.
        final String labelColName = getLabelCol();
        final boolean isLogistic = getLinearModelType() == HasConstrainedLinearModelType.LinearModelType.LR;
        if (!isLogistic && !this.inScorcard) {
            return;
        }
        Tuple2<BatchOperator, DataSet<Object>> relabeled =
            ModelSummaryHelper.transformLrLabel(this.in, labelColName, str, getMLEnvironmentId());
        this.in = (BatchOperator) relabeled.f0;
        this.labels = (DataSet) relabeled.f1;
    }

    /**
     * Initializes {@code vectorSizes} and {@code hasVectorSizes}.
     *
     * <p>When {@code z} is true, {@code CalcVectorSize(iArr)} is run over every partition of the
     * current input and the per-partition arrays are merged by element-wise maximum. Otherwise a
     * placeholder data set containing exactly one empty {@code int[]} is installed, because the
     * caller always broadcasts {@code vectorSizes}.
     *
     * @param iArr column indices handed to {@code CalcVectorSize}; ignored when {@code z} is false
     * @param z    whether the sizes should actually be computed
     */
    private void calcVectorSizes(int[] iArr, boolean z) {
        if (z) {
            this.vectorSizes = this.in.getDataSet().mapPartition(new CalcVectorSize(iArr)).reduce(new ReduceFunction<int[]>() {
                private static final long serialVersionUID = -3014179424640804678L;

                @Override
                public int[] reduce(int[] left, int[] right) {
                    // Element-wise max of two partial results.
                    // NOTE(review): assumes both arrays have the same length (one slot per column).
                    int[] merged = new int[left.length];
                    for (int k = 0; k < left.length; k++) {
                        merged[k] = Math.max(left[k], right[k]);
                    }
                    return merged;
                }
            });
            this.hasVectorSizes = true;
        } else {
            // Passed as int[][] so the varargs fromElements receives exactly one element: an empty int[].
            // (The decompiled form `(Object[]) new int[]{new int[0]}` was not valid Java.)
            this.vectorSizes = MLEnvironmentFactory.get(this.in.getMLEnvironmentId()).getExecutionEnvironment().fromElements(new int[][] {new int[0]});
            this.hasVectorSizes = false;
        }
    }

    /**
     * Prepares the feature/label columns read by the step-wise search and records their names and
     * indices in selectColNew, selectedColIdxNew, labelColNew and labelIdxNew.
     *
     * <p>The two parameter checks are independent; if both params are present both branches run and
     * the SELECTED_COLS branch's assignments win.
     * <ul>
     *   <li>SELECTED_COL: the named column is used as-is and a placeholder vectorSizes is installed
     *       via calcVectorSizes(null, false).</li>
     *   <li>SELECTED_COLS: vector sizes are computed for the listed columns, then {@code this.in} is
     *       replaced by a two-column table (INNER_VECTOR_COL, INNER_LABLE_COL) produced by
     *       VectorAssembler with vectorSizes broadcast in.</li>
     * </ul>
     *
     * @param i               label column index (computed by the caller before standardLabel ran).
     *                        NOTE(review): if transformLrLabel changes the column layout this index
     *                        may be stale — confirm.
     * @param typeInformation label type used for the assembled table when not LR / scorecard
     * @param linearModelType model type; LR (or scorecard mode) forces a DOUBLE label column
     */
    private void transformToVector(int i, TypeInformation typeInformation, HasConstrainedLinearModelType.LinearModelType linearModelType) {
        if (getParams().contains(BaseStepwiseSelectorParams.SELECTED_COL)) {
            this.selectedCol = getSelectedCol();
            // NOTE(review): when selectedCol is null/empty, selectColNew/selectedColIdxNew are left unset here.
            if (this.selectedCol != null && !this.selectedCol.isEmpty()) {
                this.selectColNew = this.selectedCol;
                this.selectedColIdxNew = TableUtil.findColIndexWithAssertAndHint(this.in.getColNames(), this.selectedCol);
            }
            // No sizes to compute for a single column; install the placeholder broadcast set.
            calcVectorSizes(null, false);
            this.labelColNew = getLabelCol();
            this.labelIdxNew = TableUtil.findColIndexWithAssertAndHint(this.in.getColNames(), getLabelCol());
        }
        if (getParams().contains(BaseStepwiseSelectorParams.SELECTED_COLS)) {
            this.selectedCols = getSelectedCols();
            int[] findColIndicesWithAssertAndHint = TableUtil.findColIndicesWithAssertAndHint(this.in.getColNames(), this.selectedCols);
            // Must run before this.in is replaced: it reads the original input and the result is broadcast below.
            calcVectorSizes(findColIndicesWithAssertAndHint, true);
            // Schema of the assembled table: [feature vector, label].
            TypeInformation[] typeInformationArr = new TypeInformation[2];
            typeInformationArr[0] = AlinkTypes.VECTOR;
            if (linearModelType == HasConstrainedLinearModelType.LinearModelType.LR || this.inScorcard) {
                typeInformationArr[1] = Types.DOUBLE;
            } else {
                typeInformationArr[1] = typeInformation;
            }
            this.in = BatchOperator.fromTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), (DataSet<Row>) this.in.getDataSet().map(new VectorAssembler(findColIndicesWithAssertAndHint, i)).withBroadcastSet(this.vectorSizes, "vectorSizes"), new String[]{INNER_VECTOR_COL, INNER_LABLE_COL}, (TypeInformation<?>[]) typeInformationArr));
            // Positions follow the two-column schema built above.
            this.selectColNew = INNER_VECTOR_COL;
            this.labelColNew = INNER_LABLE_COL;
            this.selectedColIdxNew = 0;
            this.labelIdxNew = 1;
        }
    }

    /**
     * Attaches a single-parallelism {@code ProcViz} flatMap (named "WriteStepWiseViz") over the
     * step-wise result rows and terminates it with {@code DataSetUtil.linkDummySink}, presumably so
     * that this otherwise sink-less branch is executed.
     *
     * @param dataSet                 step-wise search result rows
     * @param linearModelType         model type forwarded to ProcViz
     * @param strArr                  selected column names forwarded to ProcViz
     * @param vizDataWriterInterface  writer forwarded to ProcViz
     */
    private static void writeVizData(DataSet<Row> dataSet, HasConstrainedLinearModelType.LinearModelType linearModelType, String[] strArr, VizDataWriterInterface vizDataWriterInterface) {
        ProcViz vizStep = new ProcViz(linearModelType, strArr, vizDataWriterInterface);
        DataSetUtil.linkDummySink(
            dataSet.flatMap(vizStep)
                .setParallelism(1)
                .name("WriteStepWiseViz"));
    }

    /**
     * Maps positions back to column names.
     *
     * @param strArr candidate column names
     * @param iArr   positions into {@code strArr}
     * @return an array whose k-th entry is {@code strArr[iArr[k]]}, or {@code null} when
     *     {@code strArr} is null or empty
     */
    public static String[] getSCurSelectedCols(String[] strArr, int[] iArr) {
        boolean noNames = strArr == null || strArr.length == 0;
        if (noNames) {
            return null;
        }
        return Arrays.stream(iArr)
            .mapToObj(pos -> strArr[pos])
            .toArray(String[]::new);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Expands one logical feature index into the flat indices it covers.
     *
     * @param i    logical feature index
     * @param iArr optional offsets; when non-null, feature {@code i} covers the half-open range
     *             {@code [iArr[i], iArr[i + 1])}
     * @return {@code [i]} when {@code iArr} is null, otherwise every index in that range
     */
    public static List<Integer> getOne(int i, int[] iArr) {
        List<Integer> indices = new ArrayList<>();
        if (iArr != null) {
            int end = iArr[i + 1];
            for (int idx = iArr[i]; idx < end; idx++) {
                indices.add(idx);
            }
        } else {
            indices.add(i);
        }
        return indices;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Extracts the (statistic, p-value)-style pair used to rank a candidate in the forward step.
     *
     * <ul>
     *   <li>fTest: ({@code fValue}, {@code pValue}) of a {@code LinearRegressionSummary}</li>
     *   <li>scoreTest: ({@code scoreChiSquareValue}, {@code scorePValue}) of a
     *       {@code LogistRegressionSummary}</li>
     *   <li>marginalContribution: {@code (modelSummary.loss - previousLoss) / modelSummary.count}
     *       returned in both slots, where previousLoss is {@code modelSummary2.loss} or 0 if
     *       {@code modelSummary2} is null</li>
     * </ul>
     *
     * @throws RuntimeException for any other step-wise type
     */
    public static Tuple2<Double, Double> getForwardValue(ModelSummary modelSummary, ModelSummary modelSummary2, StepWiseType stepWiseType) {
        switch (stepWiseType) {
            case fTest: {
                LinearRegressionSummary linSummary = (LinearRegressionSummary) modelSummary;
                return Tuple2.of(Double.valueOf(linSummary.fValue), Double.valueOf(linSummary.pValue));
            }
            case scoreTest: {
                LogistRegressionSummary logSummary = (LogistRegressionSummary) modelSummary;
                return Tuple2.of(Double.valueOf(logSummary.scoreChiSquareValue), Double.valueOf(logSummary.scorePValue));
            }
            case marginalContribution: {
                double previousLoss = modelSummary2 == null ? 0.0d : modelSummary2.loss;
                double normalizedGain = (modelSummary.loss - previousLoss) / modelSummary.count;
                return Tuple2.of(Double.valueOf(normalizedGain), Double.valueOf(normalizedGain));
            }
            default:
                throw new RuntimeException("It is not support.");
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Computes one value per currently selected feature, used to rank features for removal in the
     * backward step.
     *
     * <ul>
     *   <li>fTest: {@code tPVaues} of the {@code LinearRegressionSummary}, returned as-is</li>
     *   <li>scoreTest: {@code waldPValues} of the {@code LogistRegressionSummary} without entry 0
     *       (presumably the intercept). NOTE(review): fTest does not drop an entry — confirm the two
     *       arrays are laid out differently.</li>
     *   <li>marginalContribution: for each feature in {@code list}, the summary is recomputed with
     *       that feature's flat indices left out, giving {@code (reducedLoss - loss) / count}</li>
     * </ul>
     *
     * @param iArr optional offsets used by {@code getOne} to expand features into flat indices
     * @param list currently selected logical feature indices
     * @throws RuntimeException for any other step-wise type
     */
    public static double[] getBackwardValues(ModelSummary modelSummary, BaseVectorSummarizer baseVectorSummarizer, StepWiseType stepWiseType, int[] iArr, List<Integer> list) {
        switch (stepWiseType) {
            case fTest:
                return ((LinearRegressionSummary) modelSummary).tPVaues;
            case scoreTest: {
                double[] waldP = ((LogistRegressionSummary) modelSummary).waldPValues;
                return Arrays.copyOfRange(waldP, 1, waldP.length);
            }
            case marginalContribution: {
                LinearModelType summaryType = modelSummary instanceof LogistRegressionSummary ? LinearModelType.LR : LinearModelType.LinearReg;
                int featureCount = list.size();
                double[] contributions = new double[featureCount];
                for (int drop = 0; drop < featureCount; drop++) {
                    // Flat indices of every selected feature except the one being dropped, in order.
                    List<Integer> kept = new ArrayList<>();
                    for (int k = 0; k < featureCount; k++) {
                        if (k != drop) {
                            kept.addAll(getOne(list.get(k), iArr));
                        }
                    }
                    double reducedLoss = LocalLinearModel.calcModelSummary(
                        Tuple4.of(modelSummary.beta, modelSummary.gradient, modelSummary.hessian, Double.valueOf(modelSummary.loss)),
                        baseVectorSummarizer,
                        summaryType,
                        getIndicesFromList(kept)).loss;
                    contributions[drop] = (reducedLoss - modelSummary.loss) / modelSummary.count;
                }
                return contributions;
            }
            default:
                throw new RuntimeException("It is not support.");
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Unboxes a list of indices into an {@code int[]} in the same order.
     *
     * @param list indices; must not contain null
     * @return a new array with the same elements
     */
    public static int[] getIndicesFromList(List<Integer> list) {
        return list.stream()
            .mapToInt(Integer::intValue)
            .toArray();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the index of the largest element; ties resolve to the lowest index.
     *
     * <p>Comparison uses {@code <}, so a NaN entry is never chosen unless it sits at index 0.
     *
     * @param dArr values to scan
     * @return index of the maximum
     * @throws RuntimeException if {@code dArr} is null or empty
     */
    public static int argmax(double[] dArr) {
        // Must be ||: with && a null array dereferenced .length (NPE) and an empty array
        // slipped through to dArr[0] (ArrayIndexOutOfBoundsException).
        if (dArr == null || dArr.length == 0) {
            throw new RuntimeException("max values is null.");
        }
        int best = 0;
        double bestValue = dArr[0];
        for (int k = 1; k < dArr.length; k++) {
            if (bestValue < dArr[k]) {
                bestValue = dArr[k];
                best = k;
            }
        }
        return best;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Tests whether {@code i} occurs in {@code iArr}.
     *
     * @param iArr indices to search; must not be null
     * @param i    index to look for
     * @return {@code true} if some element equals {@code i}
     */
    public static Boolean isIdxExist(int[] iArr, int i) {
        return Arrays.stream(iArr).anyMatch(candidate -> candidate == i);
    }

    /**
     * Decompiler-emitted bridge for {@code linkFrom}: casts the array to
     * {@code BatchOperator<?>[]} and delegates to that overload.
     * NOTE(review): bridges are normally compiler-generated; as literal source this appears to
     * share an erasure with the overload it calls — confirm it is not recompiled as-is.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ BaseStepWiseSelectorBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
