package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.MLEnvironment;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.common.viz.AlinkViz;
import com.alibaba.alink.common.viz.VizDataWriterForModelInfo;
import com.alibaba.alink.common.viz.VizDataWriterInterface;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.finance.ScorecardTrainBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.linear.BaseConstrainedLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.SquareLossFunc;
import com.alibaba.alink.operator.common.optim.FeatureConstraint;
import com.alibaba.alink.operator.common.optim.Lbfgs;
import com.alibaba.alink.operator.common.optim.Newton;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintObjFunc;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintVariable;
import com.alibaba.alink.operator.common.optim.activeSet.Sqp;
import com.alibaba.alink.operator.common.optim.barrierIcq.LogBarrier;
import com.alibaba.alink.operator.common.optim.local.ConstrainedLocalOptimizer;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.operator.common.statistics.basicstatistic.SparseVectorSummary;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.params.finance.ConstrainedLinearModelParams;
import com.alibaba.alink.params.finance.ConstrainedLogisticRegressionTrainParams;
import com.alibaba.alink.params.finance.HasConstrainedOptimizationMethod;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.linear.LinearConstraint;
import org.apache.commons.math3.optim.linear.LinearConstraintSet;
import org.apache.commons.math3.optim.linear.LinearObjectiveFunction;
import org.apache.commons.math3.optim.linear.Relationship;
import org.apache.commons.math3.optim.linear.SimplexSolver;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@FeatureColsVectorColMutexRule
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
/* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp.class */
public abstract class BaseConstrainedLinearModelTrainBatchOp<T extends BaseConstrainedLinearModelTrainBatchOp<T>> extends BatchOperator<T> implements AlinkViz<T> {
    private static final long serialVersionUID = 1180583968098354917L;
    private String modelName;
    private LinearModelType linearModelType;
    private static final int NUM_FEATURE_THRESHOLD = 10000;
    private static final String META = "meta";
    private static final String MEAN_VAR = "meanVar";
    private static final String VECTOR_SIZE = "vectorSize";
    private static final String LABEL_VALUES = "labelValues";

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp$BuildLabels.class */
    public static class BuildLabels implements FlatMapFunction<Tuple3<DenseVector[], Object[], Integer[]>, Object[]> {
        private boolean isRegProc;
        private String positiveLabel;
        private static final long serialVersionUID = 5375954526931728363L;

        BuildLabels(boolean z, String str) {
            this.isRegProc = z;
            this.positiveLabel = str;
        }

        public void flatMap(Tuple3<DenseVector[], Object[], Integer[]> tuple3, Collector<Object[]> collector) throws Exception {
            if (this.isRegProc) {
                collector.collect(tuple3.f1);
            } else {
                Preconditions.checkState(((Object[]) tuple3.f1).length == 2, "labels count should be 2 in in classification algo.");
                collector.collect(BaseConstrainedLinearModelTrainBatchOp.orderLabels((Object[]) tuple3.f1, this.positiveLabel));
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Tuple3<DenseVector[], Object[], Integer[]>) obj, (Collector<Object[]>) collector);
        }
    }

    /**
     * Optimizers selectable for constrained linear-model training, resolved by name with
     * {@code valueOf(String)}.
     *
     * <p>Note: decompiled from a protected nested enum; the decompiler widened the access modifier.
     * Constant order is significant for {@code values()}/{@code ordinal()} and must not change.
     */
    public enum ConstrainedOptMethod {
        /** Dispatched to {@code Sqp}. */
        SQP,
        /** Dispatched to {@code LogBarrier}. */
        BARRIER,
        /** Dispatched to {@code Lbfgs}; feature constraints are not extracted for this method. */
        LBFGS,
        /** Dispatched to {@code Newton}; feature constraints are not extracted for this method. */
        NEWTON;
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp$CreateMeta.class */
    public static class CreateMeta implements MapPartitionFunction<Object, Params> {
        private static final long serialVersionUID = -7148219424266582224L;
        private String modelName;
        private LinearModelType modelType;
        private boolean hasInterceptItem;
        private String vectorColName;
        private String labelName;
        private boolean calcLabel;
        private String positiveLabel;

        public CreateMeta(String str, LinearModelType linearModelType, Params params, boolean z, String str2) {
            this.modelName = str;
            this.modelType = linearModelType;
            this.hasInterceptItem = ((Boolean) params.get(LinearTrainParams.WITH_INTERCEPT)).booleanValue();
            this.vectorColName = (String) params.get(LinearTrainParams.VECTOR_COL);
            this.labelName = (String) params.get(LinearTrainParams.LABEL_COL);
            this.calcLabel = z;
            this.positiveLabel = str2;
        }

        public void mapPartition(Iterable<Object> iterable, Collector<Params> collector) throws Exception {
            Object[] objArr = null;
            if (this.calcLabel) {
                objArr = BaseConstrainedLinearModelTrainBatchOp.orderLabels(iterable, this.positiveLabel);
            }
            Params params = new Params();
            params.set((ParamInfo<ParamInfo<String>>) ModelParamName.MODEL_NAME, (ParamInfo<String>) this.modelName);
            params.set((ParamInfo<ParamInfo<LinearModelType>>) ModelParamName.LINEAR_MODEL_TYPE, (ParamInfo<LinearModelType>) this.modelType);
            params.set((ParamInfo<ParamInfo<Object[]>>) ModelParamName.LABEL_VALUES, (ParamInfo<Object[]>) objArr);
            params.set((ParamInfo<ParamInfo<Boolean>>) ModelParamName.HAS_INTERCEPT_ITEM, (ParamInfo<Boolean>) Boolean.valueOf(this.hasInterceptItem));
            params.set((ParamInfo<ParamInfo<String>>) ModelParamName.VECTOR_COL_NAME, (ParamInfo<String>) this.vectorColName);
            params.set((ParamInfo<ParamInfo<String>>) LinearTrainParams.LABEL_COL, (ParamInfo<String>) this.labelName);
            collector.collect(params);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp$DimTrans.class */
    private static class DimTrans extends AbstractRichFunction implements MapFunction<Integer, Integer> {
        private static final long serialVersionUID = 1997987979691400583L;
        private boolean hasInterceptItem;
        private Integer featureDim = null;

        public DimTrans(boolean z) {
            this.hasInterceptItem = z;
        }

        public void open(Configuration configuration) throws Exception {
            this.featureDim = (Integer) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
        }

        public Integer map(Integer num) throws Exception {
            return Integer.valueOf(this.featureDim.intValue() + (this.hasInterceptItem ? 1 : 0));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp$GenerateConstraint.class */
    public static class GenerateConstraint implements MapFunction<Row, Row> {
        private static final long serialVersionUID = -6999309059934707482L;

        private GenerateConstraint() {
        }

        public Row map(Row row) {
            return Row.of(new Object[]{row.getField(0) instanceof FeatureConstraint ? (FeatureConstraint) row.getField(0) : FeatureConstraint.fromJson((String) row.getField(0))});
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp$GetConstraint.class */
    public static class GetConstraint extends RichMapFunction<OptimObjFunc, OptimObjFunc> {
        private static final long serialVersionUID = -7810872210451727729L;
        private int coefDim;
        private FeatureConstraint constraint;
        private String[] featureColNames;
        private boolean hasInterceptItem;
        private ConstrainedOptMethod method;
        private DenseVector countZero = null;
        private Map<String, Boolean> hasElse;

        GetConstraint(String[] strArr, boolean z, String str, Map<String, Boolean> map) {
            this.featureColNames = strArr;
            this.hasInterceptItem = z;
            this.method = ConstrainedOptMethod.valueOf(str.toUpperCase());
            this.hasElse = map;
        }

        public void open(Configuration configuration) throws Exception {
            this.coefDim = ((Integer) getRuntimeContext().getBroadcastVariable(OptimVariable.coef).get(0)).intValue();
            this.constraint = (FeatureConstraint) ((Row) getRuntimeContext().getBroadcastVariable(ConstraintVariable.constraints).get(0)).getField(0);
            if (this.constraint.fromScorecard()) {
                this.countZero = (DenseVector) getRuntimeContext().getBroadcastVariable("countZero").get(0);
            }
        }

        public OptimObjFunc map(OptimObjFunc optimObjFunc) throws Exception {
            ConstraintObjFunc constraintObjFunc = (ConstraintObjFunc) optimObjFunc;
            if (!ConstrainedOptMethod.LBFGS.equals(this.method) && !ConstrainedOptMethod.NEWTON.equals(this.method)) {
                ConstrainedLocalOptimizer.extractConstraintsForFeatureAndBin(this.constraint, constraintObjFunc, this.featureColNames, this.coefDim, this.hasInterceptItem, this.countZero, this.hasElse);
                if (constraintObjFunc.equalityItem.size() + constraintObjFunc.inequalityItem.size() != 0) {
                    int numCols = constraintObjFunc.equalityConstraint.numRows() != 0 ? constraintObjFunc.equalityConstraint.numCols() : constraintObjFunc.inequalityConstraint.numCols();
                    OptimizationData linearObjectiveFunction = new LinearObjectiveFunction(new double[numCols], Criteria.INVALID_GAIN);
                    ArrayList arrayList = new ArrayList();
                    for (int i = 0; i < constraintObjFunc.equalityItem.size(); i++) {
                        double[] dArr = new double[numCols];
                        System.arraycopy(constraintObjFunc.equalityConstraint.getRow(i), 0, dArr, 0, numCols);
                        arrayList.add(new LinearConstraint(dArr, Relationship.EQ, constraintObjFunc.equalityItem.get(i)));
                    }
                    for (int i2 = 0; i2 < constraintObjFunc.inequalityItem.size(); i2++) {
                        double[] dArr2 = new double[numCols];
                        System.arraycopy(constraintObjFunc.inequalityConstraint.getRow(i2), 0, dArr2, 0, numCols);
                        arrayList.add(new LinearConstraint(dArr2, Relationship.GEQ, constraintObjFunc.inequalityItem.get(i2)));
                    }
                    try {
                        new SimplexSolver().optimize(new OptimizationData[]{linearObjectiveFunction, new LinearConstraintSet(arrayList), GoalType.MINIMIZE});
                    } catch (Exception e) {
                        throw new RuntimeException("infeasible constraint!", e);
                    }
                }
            }
            return constraintObjFunc;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp$Transform.class */
    private static class Transform extends RichMapFunction<Row, Tuple3<Double, Double, Vector>> {
        private static final long serialVersionUID = 3541655329500762922L;
        private String positiveLableValueString;
        private boolean isRegProc;
        private int weightIdx;
        private int vecIdx;
        private int labelIdx;
        private int[] featureIndices;
        private TypeInformation type;

        public Transform(boolean z, int i, int i2, int[] iArr, int i3, String str, TypeInformation typeInformation) {
            this.isRegProc = z;
            this.weightIdx = i;
            this.vecIdx = i2;
            this.featureIndices = iArr;
            this.labelIdx = i3;
            this.positiveLableValueString = str;
            this.type = typeInformation;
        }

        public void open(Configuration configuration) throws Exception {
            if (this.isRegProc) {
                return;
            }
            Object[] orderLabels = BaseConstrainedLinearModelTrainBatchOp.orderLabels(getRuntimeContext().getBroadcastVariable("labelValues"), this.positiveLableValueString);
            if (this.positiveLableValueString == null) {
                throw new RuntimeException("constrained logistic regression must set positive label!");
            }
            EvaluationUtil.ComparableLabel comparableLabel = new EvaluationUtil.ComparableLabel(this.positiveLableValueString, this.type);
            if (!comparableLabel.equals(new EvaluationUtil.ComparableLabel(orderLabels[0].toString(), this.type)) && !comparableLabel.equals(new EvaluationUtil.ComparableLabel(orderLabels[1].toString(), this.type))) {
                throw new RuntimeException("the user defined positive label is not in the data!");
            }
        }

        public Tuple3<Double, Double, Vector> map(Row row) throws Exception {
            Double valueOf = Double.valueOf(this.weightIdx != -1 ? ((Number) row.getField(this.weightIdx)).doubleValue() : 1.0d);
            Double valueOf2 = Double.valueOf(FeatureLabelUtil.getLabelValue(row, this.isRegProc, this.labelIdx, this.positiveLableValueString));
            if (this.featureIndices == null) {
                Vector vector = VectorUtil.getVector(row.getField(this.vecIdx));
                Preconditions.checkState(vector != null, "vector for linear model train is null, please check your input data.");
                return Tuple3.of(valueOf, valueOf2, vector);
            }
            DenseVector denseVector = new DenseVector(this.featureIndices.length);
            for (int i = 0; i < this.featureIndices.length; i++) {
                denseVector.set(i, ((Number) row.getField(this.featureIndices[i])).doubleValue());
            }
            return Tuple3.of(valueOf, valueOf2, denseVector);
        }
    }

    /**
     * Creates the operator.
     *
     * @param params          operator parameters, handed to the {@link BatchOperator} super-constructor
     * @param linearModelType kind of linear model this operator trains
     * @param modelName       name recorded in the model meta
     */
    public BaseConstrainedLinearModelTrainBatchOp(Params params, LinearModelType linearModelType, String modelName) {
        super(params);
        this.linearModelType = linearModelType;
        this.modelName = modelName;
    }

    /**
     * Builds the constrained linear-model training pipeline and sets the model table as this operator's output.
     *
     * <p>{@code batchOperatorArr[0]} is the training data. An optional non-null {@code batchOperatorArr[1]}
     * supplies constraint rows. It is only read when the {@code CONSTRAINT} parameter is empty.
     *
     * <p>Side effects on the params:
     * <ul>
     *   <li>{@code STANDARDIZATION} is set to {@code false} when a second input is given or a constraint JSON is set;</li>
     *   <li>the optimization method defaults to {@code SQP} when absent.</li>
     * </ul>
     *
     * @param batchOperatorArr training data, optionally followed by a constraint table
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    @SuppressWarnings("unchecked")
    public T linkFrom(BatchOperator<?>... batchOperatorArr) {
        Params params = getParams();
        BatchOperator<?> in = batchOperatorArr[0];
        MLEnvironment mlEnv = MLEnvironmentFactory.get(getMLEnvironmentId());

        BatchOperator<?> constraintOp = null;
        if (batchOperatorArr.length == 2 && batchOperatorArr[1] != null) {
            constraintOp = batchOperatorArr[1];
            params.set(LinearTrainParams.STANDARDIZATION, false);
        }
        DataSet<Row> constraintData = buildConstraintData(params, constraintOp, mlEnv);

        // Treated as binary classification for LR, or whenever the scorecard flag is set.
        boolean isClassification = LinearModelType.LR == this.linearModelType
            || params.get(ScorecardTrainBatchOp.IN_SCORECARD);
        String positiveLabel = isClassification
            ? params.get(ConstrainedLogisticRegressionTrainParams.POS_LABEL_VAL_STR)
            : null;
        if (!params.contains(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD)) {
            params.set(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD,
                HasConstrainedOptimizationMethod.ConstOptimMethod.SQP);
        }
        String optimMethod = params.get(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD)
            .toString().toUpperCase(Locale.ROOT);
        boolean isRegProc = this.linearModelType == LinearModelType.LinearReg;
        boolean standardization = params.get(LinearTrainParams.STANDARDIZATION);

        Tuple2<DataSet<Object>, TypeInformation> labelInfo = getLabelInfo(in, params, !isClassification);
        DataSet<Tuple3<Double, Object, Vector>> initData =
            BaseLinearModelTrainBatchOp.transform(in, params, isRegProc, standardization);
        DataSet<Tuple3<DenseVector[], Object[], Integer[]>> utilInfo =
            BaseLinearModelTrainBatchOp.getUtilInfo(initData, standardization, isRegProc);
        DataSet<DenseVector[]> meanVar = utilInfo.map(
            new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, DenseVector[]>() {
                private static final long serialVersionUID = 7127767376687624403L;

                @Override
                public DenseVector[] map(Tuple3<DenseVector[], Object[], Integer[]> value) throws Exception {
                    return value.f0;
                }
            });
        DataSet<Integer> vectorSize = utilInfo.map(
            new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, Integer>() {
                private static final long serialVersionUID = 2773811388068064638L;

                @Override
                public Integer map(Tuple3<DenseVector[], Object[], Integer[]> value) throws Exception {
                    return value.f2[0];
                }
            });
        DataSet<Tuple3<Double, Double, Vector>> trainData = BaseLinearModelTrainBatchOp.preProcess(initData, params,
            isRegProc, meanVar, utilInfo.flatMap(new BuildLabels(isRegProc, positiveLabel)), vectorSize);

        // Per-feature non-zero counts for sparse input (empty vector otherwise); broadcast as "countZero".
        DataSet<DenseVector> nonZeroCounts = StatisticsHelper.summary(
                trainData.map(new MapFunction<Tuple3<Double, Double, Vector>, Vector>() {
                    private static final long serialVersionUID = 6207307350053531656L;

                    @Override
                    public Vector map(Tuple3<Double, Double, Vector> value) throws Exception {
                        return value.f2;
                    }
                }).withForwardedFields(new String[0]))
            .map(new MapFunction<BaseVectorSummary, DenseVector>() {
                private static final long serialVersionUID = 2322849507320367330L;

                @Override
                public DenseVector map(BaseVectorSummary summary) throws Exception {
                    return summary instanceof SparseVectorSummary
                        ? (DenseVector) ((SparseVectorSummary) summary).numNonZero()
                        : new DenseVector(0);
                }
            });

        DataSet<Tuple2<DenseVector, double[]>> coefs = optimize(params,
            getOptParam(constraintData, params, vectorSize, this.linearModelType, mlEnv, optimMethod, nonZeroCounts),
            trainData, this.modelName, optimMethod);

        String[] featureCols = params.get(LinearTrainParams.FEATURE_COLS);
        boolean standardizationForModel = params.get(LinearTrainParams.STANDARDIZATION);
        boolean withIntercept = params.get(LinearTrainParams.WITH_INTERCEPT);
        DataSet<Params> meta = labelInfo.f0
            .mapPartition(new CreateMeta(this.modelName, this.linearModelType, params, isClassification, positiveLabel))
            .setParallelism(1);
        DataSet<Row> model = coefs
            .mapPartition(new BaseLinearModelTrainBatchOp.BuildModelFromCoefs(labelInfo.f1, featureCols,
                standardizationForModel, withIntercept, BaseLinearModelTrainBatchOp.getFeatureTypes(in, featureCols)))
            .withBroadcastSet(meta, META)
            .withBroadcastSet(meanVar, MEAN_VAR)
            .setParallelism(1);

        setOutput(model, new LinearModelDataConverter(labelInfo.f1).getModelSchema());
        writeVizData(model, vectorSize);
        return (T) this;
    }

    /**
     * Produces the single-field constraint dataset consumed by the optimizer.
     *
     * <p>When the {@code CONSTRAINT} param is empty, constraints come from {@code constraintOp} (normalized by
     * {@link GenerateConstraint}), or an empty {@link FeatureConstraint} is used when there is no such input.
     * Otherwise the JSON param is parsed, and {@code STANDARDIZATION} is set to {@code false}.
     */
    private static DataSet<Row> buildConstraintData(Params params, BatchOperator<?> constraintOp, MLEnvironment mlEnv) {
        String constraintJson = params.get(ConstrainedLinearModelParams.CONSTRAINT);
        if ("".equals(constraintJson)) {
            if (constraintOp != null) {
                return constraintOp.getDataSet().map(new GenerateConstraint());
            }
            return mlEnv.getExecutionEnvironment().fromElements(Row.of(new FeatureConstraint()));
        }
        DataSet<Row> parsed = mlEnv.getExecutionEnvironment().fromElements(Row.of(FeatureConstraint.fromJson(constraintJson)));
        params.set(LinearTrainParams.STANDARDIZATION, false);
        return parsed;
    }

    /**
     * Maps the input table to {@code (weight, label, features)} samples using {@link Transform}.
     *
     * <p>If neither feature columns nor a vector column are set, all numeric columns except the label are used.
     * This choice is written back to {@code params} as {@code FEATURE_COLS}. Every feature column must have a
     * supported numeric type.
     *
     * @param batchOperator   input operator
     * @param params          training params (may be updated with inferred feature columns)
     * @param dataSet         label values, broadcast to the mapper as {@code "labelValues"}
     * @param z               true for regression
     * @param str             string form of the positive label
     * @param typeInformation label type used to compare labels
     * @return the per-row training samples
     */
    protected static DataSet<Tuple3<Double, Double, Vector>> transform(BatchOperator batchOperator, Params params, DataSet<Object> dataSet, boolean z, String str, TypeInformation typeInformation) {
        String[] featureCols = params.get(LinearTrainParams.FEATURE_COLS);
        String labelCol = params.get(LinearTrainParams.LABEL_COL);
        String weightCol = params.get(LinearTrainParams.WEIGHT_COL);
        String vectorCol = params.get(LinearTrainParams.VECTOR_COL);
        TableSchema schema = batchOperator.getSchema();
        String[] colNames = batchOperator.getColNames();
        TypeInformation<?>[] fieldTypes = schema.getFieldTypes();

        if (null == featureCols && null == vectorCol) {
            featureCols = TableUtil.getNumericCols(schema, new String[] {labelCol});
            params.set(LinearTrainParams.FEATURE_COLS, featureCols);
        }
        int labelIdx = TableUtil.findColIndexWithAssertAndHint(colNames, labelCol);
        int[] featureIndices = null;
        if (featureCols != null) {
            featureIndices = new int[featureCols.length];
            for (int i = 0; i < featureCols.length; i++) {
                int colIdx = TableUtil.findColIndexWithAssertAndHint(colNames, featureCols[i]);
                featureIndices[i] = colIdx;
                TypeInformation<?> colType = fieldTypes[colIdx];
                Preconditions.checkState(TableUtil.isSupportedNumericType(colType),
                    "linear algorithm only support numerical data type. type is : " + colType);
            }
        }
        int weightIdx = weightCol != null ? TableUtil.findColIndexWithAssertAndHint(colNames, weightCol) : -1;
        int vecIdx = vectorCol != null ? TableUtil.findColIndexWithAssertAndHint(colNames, vectorCol) : -1;
        return batchOperator.getDataSet()
            .map(new Transform(z, weightIdx, vecIdx, featureIndices, labelIdx, str, typeInformation))
            .withBroadcastSet(dataSet, LABEL_VALUES);
    }

    /**
     * Collects the labels into an array and orders them with {@link #orderLabels(Object[], String)}.
     *
     * @param iterable label values
     * @param str      string form of the positive label
     * @return the collected labels, with the positive label moved to index 0 if it was at index 1
     */
    protected static Object[] orderLabels(Iterable<Object> iterable, String str) {
        ArrayList<Object> labels = new ArrayList<>();
        for (Object label : iterable) {
            labels.add(label);
        }
        return orderLabels(labels.toArray(new Object[0]), str);
    }

    /**
     * Moves the positive label to index 0 in place.
     *
     * <p>If the string form of {@code objArr[1]} equals {@code str}, the first two entries are swapped.
     * Otherwise the array is left unchanged.
     *
     * @param objArr label values; must contain at least two entries (mutated in place)
     * @param str    string form of the positive label (may be null, in which case nothing is swapped)
     * @return {@code objArr}
     * @throws IllegalStateException if fewer than two labels are given
     */
    protected static Object[] orderLabels(Object[] objArr, String str) {
        // Message previously said "more than 2", which contradicted the ">= 2" check.
        Preconditions.checkState(objArr.length >= 2, "labels count should be at least 2 in classification algo.");
        if (objArr[1].toString().equals(str)) {
            Object first = objArr[0];
            objArr[0] = objArr[1];
            objArr[1] = first;
        }
        return objArr;
    }

    /**
     * Builds the objective-function dataset (with constraints attached by {@link GetConstraint}) and the
     * coefficient-dimension dataset.
     *
     * <p>The coefficient dimension is {@code dataSet2} when a non-empty vector column is configured. Otherwise it
     * is the number of feature columns plus one when an intercept is used.
     *
     * @param dataSet  constraint rows, broadcast as {@code ConstraintVariable.constraints}
     * @param params   training params
     * @param dataSet2 vector size, used as the coefficient dimension in vector-column mode
     * @param linearModelType model type used to choose the loss
     * @param mLEnvironment   environment providing the execution environment
     * @param str      optimization method name
     * @param dataSet3 per-feature counts, broadcast as {@code "countZero"}
     * @return (objective function, coefficient dimension)
     * @throws IllegalStateException if neither feature columns nor a vector column are available
     */
    private static Tuple2<DataSet<OptimObjFunc>, DataSet<Integer>> getOptParam(DataSet<Row> dataSet, Params params, DataSet<Integer> dataSet2, LinearModelType linearModelType, MLEnvironment mLEnvironment, String str, DataSet<DenseVector> dataSet3) {
        boolean hasIntercept = params.get(LinearTrainParams.WITH_INTERCEPT);
        String[] featureCols = params.get(LinearTrainParams.FEATURE_COLS);
        String vectorCol = params.get(LinearTrainParams.VECTOR_COL);
        if (ArrayUtils.isEmpty(featureCols)) {
            featureCols = null;
        }
        boolean useVectorCol = vectorCol != null && !vectorCol.isEmpty();

        DataSet<Integer> coefDim;
        if (useVectorCol) {
            coefDim = dataSet2;
        } else {
            // Previously an NPE on featureCols.length; fail with an explicit message instead.
            Preconditions.checkState(featureCols != null,
                "featureCols or vectorCol must be set for constrained linear model training.");
            coefDim = mLEnvironment.getExecutionEnvironment().fromElements(featureCols.length + (hasIntercept ? 1 : 0));
        }

        DataSet<OptimObjFunc> objFunc = mLEnvironment.getExecutionEnvironment()
            .fromElements(new OptimObjFunc[] {getObjFunction(linearModelType, params)})
            .map(new GetConstraint(featureCols, hasIntercept, str, (Map) params.get(ScorecardTrainBatchOp.WITH_ELSE)))
            .withBroadcastSet(coefDim, OptimVariable.coef)
            .withBroadcastSet(dataSet, ConstraintVariable.constraints)
            .withBroadcastSet(dataSet3, "countZero");
        return Tuple2.of(objFunc, coefDim);
    }

    /**
     * Runs the configured constrained optimizer.
     *
     * <p>If the optimization-method param is absent, SQP is used regardless of {@code str2}. Otherwise
     * {@code str2} selects one of {@link ConstrainedOptMethod}.
     *
     * @param params  training params
     * @param tuple2  (objective function, coefficient dimension)
     * @param dataSet training samples
     * @param str     model name (unused here)
     * @param str2    upper-case optimization method name
     * @return the optimizer's coefficient output
     * @throws IllegalArgumentException if {@code str2} names no supported method
     */
    public static DataSet<Tuple2<DenseVector, double[]>> optimize(Params params, Tuple2<DataSet<OptimObjFunc>, DataSet<Integer>> tuple2, DataSet<Tuple3<Double, Double, Vector>> dataSet, String str, String str2) {
        DataSet<OptimObjFunc> objFunc = tuple2.f0;
        DataSet<Integer> coefDim = tuple2.f1;
        if (!params.contains(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD)) {
            return new Sqp(objFunc, dataSet, coefDim, params).optimize();
        }
        ConstrainedOptMethod method;
        try {
            method = ConstrainedOptMethod.valueOf(str2);
        } catch (IllegalArgumentException e) {
            // valueOf used to throw before the switch's default branch could report this.
            throw new IllegalArgumentException("do not support the " + str2 + " method!", e);
        }
        switch (method) {
            case SQP:
                return new Sqp(objFunc, dataSet, coefDim, params).optimize();
            case BARRIER:
                return new LogBarrier(objFunc, dataSet, coefDim, params).optimize();
            case LBFGS:
                return new Lbfgs(objFunc, dataSet, coefDim, params).optimize();
            case NEWTON:
                return new Newton(objFunc, dataSet, coefDim, params).optimize();
            default:
                throw new RuntimeException("do not support the " + str2 + " method!");
        }
    }

    /**
     * Creates the constrained objective for a model type.
     *
     * <p>Linear regression uses the square loss; logistic regression (LR) uses the log loss.
     *
     * @param linearModelType model type
     * @param params          params forwarded to {@link ConstraintObjFunc}
     * @return the objective function
     * @throws RuntimeException for any other model type (including null)
     */
    public static OptimObjFunc getObjFunction(LinearModelType linearModelType, Params params) {
        if (linearModelType == LinearModelType.LR) {
            return new ConstraintObjFunc(new LogLossFunc(), params);
        }
        if (linearModelType == LinearModelType.LinearReg) {
            return new ConstraintObjFunc(new SquareLossFunc(), params);
        }
        throw new RuntimeException("Not implemented yet!");
    }

    /**
     * Writes model-info visualization data when a viz writer is configured.
     *
     * <p>If the broadcast vector size exceeds {@code NUM_FEATURE_THRESHOLD}, a single explanatory row is written
     * instead of the model rows.
     *
     * @param dataSet  model rows
     * @param dataSet2 vector size, broadcast as {@code "vectorSize"}
     */
    private void writeVizData(DataSet<Row> dataSet, DataSet<Integer> dataSet2) {
        VizDataWriterInterface vizDataWriter = getVizDataWriter();
        if (vizDataWriter == null) {
            return;
        }
        DataSet<Row> vizModel = dataSet.mapPartition(new RichMapPartitionFunction<Row, Row>() {
            private static final long serialVersionUID = -7146244281747193903L;

            @Override
            public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) {
                int numFeatures = (Integer) getRuntimeContext().getBroadcastVariable(VECTOR_SIZE).get(0);
                if (numFeatures > NUM_FEATURE_THRESHOLD) {
                    collector.collect(Row.of("Not support models with #features > " + NUM_FEATURE_THRESHOLD));
                    return;
                }
                // Replaces the decompiled lambda, which referenced an undefined variable.
                for (Row row : iterable) {
                    collector.collect(row);
                }
            }
        }).withBroadcastSet(dataSet2, VECTOR_SIZE).setParallelism(1);
        VizDataWriterForModelInfo.writeModelInfo(vizDataWriter, getClass().getSimpleName(),
            getOutputTable().getSchema(), vizModel, getParams());
    }

    /**
     * Computes the feature statistics needed by training.
     *
     * <p>A vector summary of the feature column ({@code f2}) is always built; from it the third
     * result is derived (per-feature non-zero counts for sparse summaries, an empty vector otherwise).
     *
     * <p>When {@code z} is true, the vector size and the {mean, scale} pair both come from the summary.
     * For dense summaries these are the mean and standard deviation. For sparse summaries the center
     * is all zeros and the scale is {@code max(|max|, |min|)} per feature, so zero entries stay zero.
     * When {@code z} is false, the vector size is computed with a lightweight per-partition scan,
     * and the {mean, scale} pair is a pair of empty placeholder vectors.
     *
     * @param dataSet training samples; only the feature vector {@code f2} is read
     * @param z       whether to compute full mean/scale statistics from the summary
     * @return (vector size, {mean, scale}, non-zero counts)
     */
    protected static Tuple3<DataSet<Integer>, DataSet<DenseVector[]>, DataSet<DenseVector>> getStatInfo(DataSet<Tuple3<Double, Double, Vector>> dataSet, boolean z) {
        DataSet<BaseVectorSummary> summary = StatisticsHelper.summary(dataSet.map(new MapFunction<Tuple3<Double, Double, Vector>, Vector>() {
            private static final long serialVersionUID = 6207307350053531656L;

            @Override
            public Vector map(Tuple3<Double, Double, Vector> tuple3) throws Exception {
                return tuple3.f2;
            }
        }).withForwardedFields(new String[0]));

        DataSet<DenseVector> numNonZero = summary.map(new MapFunction<BaseVectorSummary, DenseVector>() {
            private static final long serialVersionUID = 2322849507320367330L;

            @Override
            public DenseVector map(BaseVectorSummary baseVectorSummary) throws Exception {
                return baseVectorSummary instanceof SparseVectorSummary
                    ? (DenseVector) ((SparseVectorSummary) baseVectorSummary).numNonZero()
                    : new DenseVector(0);
            }
        });

        if (z) {
            DataSet<Integer> vectorSize = summary.map(new MapFunction<BaseVectorSummary, Integer>() {
                private static final long serialVersionUID = -8051245706564042978L;

                @Override
                public Integer map(BaseVectorSummary baseVectorSummary) throws Exception {
                    return baseVectorSummary.vectorSize();
                }
            });
            DataSet<DenseVector[]> meanVar = summary.map(new MapFunction<BaseVectorSummary, DenseVector[]>() {
                private static final long serialVersionUID = -6992060467629008691L;

                @Override
                public DenseVector[] map(BaseVectorSummary baseVectorSummary) {
                    if (!(baseVectorSummary instanceof SparseVectorSummary)) {
                        return new DenseVector[]{(DenseVector) baseVectorSummary.mean(), (DenseVector) baseVectorSummary.standardDeviation()};
                    }
                    // Sparse input: use a zero center and max-abs scale, so standardization is a
                    // pure rescaling and zero entries stay zero.
                    DenseVector maxAbs = ((SparseVector) baseVectorSummary.max()).toDenseVector();
                    DenseVector center = ((SparseVector) baseVectorSummary.min()).toDenseVector();
                    for (int i = 0; i < maxAbs.size(); i++) {
                        maxAbs.set(i, Math.max(Math.abs(maxAbs.get(i)), Math.abs(center.get(i))));
                        center.set(i, 0.0);
                    }
                    return new DenseVector[]{center, maxAbs};
                }
            });
            return Tuple3.of(vectorSize, meanVar, numNonZero);
        }

        DataSet<Integer> vectorSize = dataSet.mapPartition(new MapPartitionFunction<Tuple3<Double, Double, Vector>, Integer>() {
            private static final long serialVersionUID = 3426157421982727224L;

            @Override
            public void mapPartition(Iterable<Tuple3<Double, Double, Vector>> iterable, Collector<Integer> collector) throws Exception {
                int size = -1;
                for (Tuple3<Double, Double, Vector> sample : iterable) {
                    Vector features = sample.f2;
                    if (features instanceof DenseVector) {
                        // Dense rows carry their full length; the first one decides the partition's size.
                        size = ((DenseVector) features).getData().length;
                        break;
                    }
                    for (int index : ((SparseVector) features).getIndices()) {
                        size = Math.max(size, index + 1);
                    }
                    size = Math.max(size, features.size());
                }
                collector.collect(size);
            }
        }).reduceGroup(new GroupReduceFunction<Integer, Integer>() {
            private static final long serialVersionUID = 2752381384411882555L;

            @Override
            public void reduce(Iterable<Integer> iterable, Collector<Integer> collector) {
                int size = -1;
                for (Integer partitionSize : iterable) {
                    size = Math.max(size, partitionSize);
                }
                collector.collect(size);
            }
        });

        // Without standardization statistics, emit empty placeholders for {mean, scale}.
        DataSet<DenseVector[]> meanVar = vectorSize.map(new MapFunction<Integer, DenseVector[]>() {
            private static final long serialVersionUID = 5448632685946933829L;

            @Override
            public DenseVector[] map(Integer num) {
                return new DenseVector[]{new DenseVector(0), new DenseVector(0)};
            }
        });
        return Tuple3.of(vectorSize, meanVar, numNonZero);
    }

    /**
     * Applies optional standardization and an optional intercept column to every training sample.
     *
     * <p>Per feature vector ({@code f2}), with {@code mean = meanVar[0]} and {@code scale = meanVar[1]}:
     * <ul>
     *   <li>dense, standardization and intercept: a new vector {@code [1, (x - mean) / scale]};</li>
     *   <li>dense, standardization only: {@code x / scale}, written into the incoming vector;</li>
     *   <li>sparse, standardization: the stored values are rewritten in place (centered first when an
     *       intercept is used), then prefixed with 1 when an intercept is used;</li>
     *   <li>no standardization: only the intercept prefix, if requested, is added.</li>
     * </ul>
     * Fields {@code f0} and {@code f1} are passed through untouched.
     *
     * @param dataSet  samples to transform
     * @param params   training parameters (reads WITH_INTERCEPT and STANDARDIZATION)
     * @param dataSet2 broadcast (registered as MEAN_VAR) whose first element is the {mean, scale} pair;
     *                 it is adjusted by {@code modifyMeanVar} when the function opens
     * @return the transformed samples
     */
    protected static DataSet<Tuple3<Double, Double, Vector>> preProcess(DataSet<Tuple3<Double, Double, Vector>> dataSet, Params params, DataSet<DenseVector[]> dataSet2) {
        final boolean withIntercept = ((Boolean) params.get(LinearTrainParams.WITH_INTERCEPT)).booleanValue();
        final boolean standardization = ((Boolean) params.get(LinearTrainParams.STANDARDIZATION)).booleanValue();

        RichMapFunction<Tuple3<Double, Double, Vector>, Tuple3<Double, Double, Vector>> transformer =
            new RichMapFunction<Tuple3<Double, Double, Vector>, Tuple3<Double, Double, Vector>>() {
                private static final long serialVersionUID = -5342628140781184056L;
                private DenseVector[] meanVar;

                @Override
                public void open(Configuration configuration) throws Exception {
                    meanVar = (DenseVector[]) getRuntimeContext().getBroadcastVariable(BaseConstrainedLinearModelTrainBatchOp.MEAN_VAR).get(0);
                    BaseConstrainedLinearModelTrainBatchOp.modifyMeanVar(standardization, meanVar);
                }

                @Override
                public Tuple3<Double, Double, Vector> map(Tuple3<Double, Double, Vector> sample) throws Exception {
                    Vector features = sample.f2;
                    if (features instanceof DenseVector) {
                        return Tuple3.of(sample.f0, sample.f1, transformDense((DenseVector) features));
                    }
                    return Tuple3.of(sample.f0, sample.f1, transformSparse((SparseVector) features));
                }

                /** Standardizes (in place) and/or prefixes a sparse feature vector. */
                private SparseVector transformSparse(SparseVector sv) {
                    if (standardization) {
                        int[] idx = sv.getIndices();
                        double[] vals = sv.getValues();
                        for (int k = 0; k < idx.length; k++) {
                            int col = idx[k];
                            vals[k] = withIntercept
                                ? (vals[k] - meanVar[0].get(col)) / meanVar[1].get(col)
                                : vals[k] / meanVar[1].get(col);
                        }
                    }
                    return withIntercept ? sv.prefix(1.0d) : sv;
                }

                /** Standardizes and/or prefixes a dense feature vector. */
                private DenseVector transformDense(DenseVector dv) {
                    if (withIntercept) {
                        DenseVector out = new DenseVector(dv.size() + 1);
                        out.set(0, 1.0d);
                        for (int k = 0; k < dv.size(); k++) {
                            double raw = dv.get(k);
                            out.set(k + 1, standardization
                                ? (raw - meanVar[0].get(k)) / meanVar[1].get(k)
                                : raw);
                        }
                        return out;
                    }
                    if (standardization) {
                        // No intercept: rescale the incoming vector in place.
                        for (int k = 0; k < dv.size(); k++) {
                            dv.set(k, dv.get(k) / meanVar[1].get(k));
                        }
                    }
                    return dv;
                }
            };

        return dataSet.map(transformer).withBroadcastSet(dataSet2, MEAN_VAR);
    }

    /**
     * Collects the distinct label values and the label column type.
     *
     * <p>When {@code z} is true, no labels are scanned. The type is reported as DOUBLE and the label
     * set is a one-element placeholder dataset. Otherwise the label column (LABEL_COL) is selected,
     * its distinct values are computed with {@code Preprocessing.distinctLabels}, and they are
     * flattened into one record per label.
     *
     * @param batchOperator training input
     * @param params        training parameters (reads LABEL_COL)
     * @param z             whether to skip label discovery and treat the label as DOUBLE
     * @return (label values, label type)
     */
    protected static Tuple2<DataSet<Object>, TypeInformation> getLabelInfo(BatchOperator batchOperator, Params params, boolean z) {
        TypeInformation<?> typeInformation;
        // Declared as DataSet: the two branches produce different operator subtypes
        // (a DataSource vs. a FlatMapOperator).
        DataSet<Object> labels;
        String str = (String) params.get(LinearTrainParams.LABEL_COL);
        if (z) {
            typeInformation = Types.DOUBLE;
            labels = MLEnvironmentFactory.get(batchOperator.getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Object[]{new Object()});
        } else {
            typeInformation = batchOperator.getColTypes()[TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str)];
            labels = Preprocessing.distinctLabels(Preprocessing.select(batchOperator, str).getDataSet().map(new MapFunction<Row, Object>() {
                private static final long serialVersionUID = -419245917074561046L;

                @Override
                public Object map(Row row) throws Exception {
                    return row.getField(0);
                }
            })).flatMap(new FlatMapFunction<Object[], Object>() {
                private static final long serialVersionUID = -5089566319196319692L;

                @Override
                public void flatMap(Object[] objArr, Collector<Object> collector) throws Exception {
                    for (Object obj : objArr) {
                        collector.collect(obj);
                    }
                }
            });
        }
        return Tuple2.of(labels, typeInformation);
    }

    /**
     * Neutralizes features whose scale is zero, so that dividing by the scale is safe.
     *
     * <p>When {@code z} (standardization) is true, every index {@code i} with
     * {@code denseVectorArr[1][i] == 0} gets scale 1 and mean 0, which leaves that feature unchanged.
     * Does nothing when {@code z} is false. The vectors are modified in place.
     *
     * <p>NOTE(review): the decompiled code compared against {@code Criteria.INVALID_GAIN}. That value
     * is assumed to be 0.0, which is the only value for which "reset scale to 1" makes sense.
     * Decompiler notes indicate this method was originally private; it stays public for compatibility.
     *
     * @param z              whether standardization is enabled
     * @param denseVectorArr {mean, scale} pair, modified in place
     */
    public static void modifyMeanVar(boolean z, DenseVector[] denseVectorArr) {
        if (!z) {
            return;
        }
        DenseVector scale = denseVectorArr[1];
        for (int i = 0; i < scale.size(); i++) {
            if (scale.get(i) == 0.0) {
                scale.set(i, 1.0d);
                denseVectorArr[0].set(i, 0.0);
            }
        }
    }

    // The decompiler-emitted synthetic bridge `linkFrom(BatchOperator[])` was removed here. In source
    // form it has the same parameter erasure as the class's `linkFrom(BatchOperator<?>...)`, which is a
    // compile-time name clash. javac regenerates the bridge automatically, and callers passing a
    // BatchOperator[] still resolve to the varargs overload.
}
