package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.MLEnvironment;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.WithTrainInfo;
import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.optim.Lbfgs;
import com.alibaba.alink.operator.common.optim.Optimizer;
import com.alibaba.alink.operator.common.optim.OptimizerFactory;
import com.alibaba.alink.operator.common.optim.Owlqn;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintVariable;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.regression.LassoRegTrainParams;
import com.alibaba.alink.params.regression.LinearSvrTrainParams;
import com.alibaba.alink.params.regression.RidgeRegTrainParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import com.alibaba.alink.params.shared.linear.HasL1;
import com.alibaba.alink.params.shared.linear.HasL2;
import com.alibaba.alink.params.shared.linear.HasWithIntercept;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(value = PortType.MODEL, isOptional = true)})
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_INFO), @PortSpec(value = PortType.DATA, desc = PortDesc.FEATURE_IMPORTANCE), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_WEIGHT)})
@FeatureColsVectorColMutexRule
@Internal
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
/* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseLinearModelTrainBatchOp.class */
public abstract class BaseLinearModelTrainBatchOp<T extends BaseLinearModelTrainBatchOp<T>> extends BatchOperator<T> implements WithTrainInfo<LinearModelTrainInfo, T> {
    private static final long serialVersionUID = 6162495789625212086L;
    /** Model name supplied by the concrete subclass; written into the model meta by {@link CreateMeta}. */
    private final String modelName;
    /** Kind of linear model trained by this operator (e.g. LinearReg, SVR, AFT, or a classification type). */
    private final LinearModelType linearModelType;
    /** Broadcast-variable name under which the model meta {@link Params} is shipped. */
    private static final String META = "meta";
    /** Broadcast-variable name under which the {@code DenseVector[]} feature statistics are shipped. */
    private static final String MEAN_VAR = "meanVar";
    /** Reserved name for label values; not referenced in this part of the class. */
    private static final String LABEL_VALUES = "labelValues";
    // NOTE(review): mutable public static field; it is not initialized here, so presumably a static initializer
    // further down the file assigns it. Consider making it final once that is confirmed.
    public static Comparator<Tuple3<String, Double, Double>> compare;
    // Appears to be a decompiler-emitted synthetic field backing `assert` statements, not hand-written state.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseLinearModelTrainBatchOp$BuildModelFromCoefs.class */
    public static class BuildModelFromCoefs extends AbstractRichFunction implements MapPartitionFunction<Tuple2<DenseVector, double[]>, Row> {
        private static final long serialVersionUID = -8526938457839413291L;
        private Params meta;
        private final String[] featureNames;
        private final String[] featureColTypes;
        private final TypeInformation<?> labelType;
        private DenseVector[] meanVar;
        private final boolean hasIntercept;
        private final boolean standardization;

        public BuildModelFromCoefs(TypeInformation<?> typeInformation, String[] strArr, boolean z, boolean z2, String[] strArr2) {
            this.labelType = typeInformation;
            this.featureNames = strArr;
            this.standardization = z;
            this.hasIntercept = z2;
            this.featureColTypes = strArr2;
        }

        public void open(Configuration configuration) throws Exception {
            this.meta = (Params) getRuntimeContext().getBroadcastVariable(BaseLinearModelTrainBatchOp.META).get(0);
            this.meta.set((ParamInfo<ParamInfo<String[]>>) ModelParamName.FEATURE_TYPES, (ParamInfo<String[]>) this.featureColTypes);
            if (LinearModelType.AFT.equals(this.meta.get(ModelParamName.LINEAR_MODEL_TYPE))) {
                this.meanVar = null;
            } else {
                this.meanVar = (DenseVector[]) getRuntimeContext().getBroadcastVariable(BaseLinearModelTrainBatchOp.MEAN_VAR).get(0);
            }
        }

        public void mapPartition(Iterable<Tuple2<DenseVector, double[]>> iterable, Collector<Row> collector) throws Exception {
            Iterator<Tuple2<DenseVector, double[]>> it = iterable.iterator();
            while (it.hasNext()) {
                new LinearModelDataConverter(this.labelType).save(BaseLinearModelTrainBatchOp.buildLinearModelData(this.meta, this.featureNames, this.labelType, this.meanVar, this.hasIntercept, this.standardization, it.next()), collector);
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/BaseLinearModelTrainBatchOp$CreateMeta.class */
    public static class CreateMeta implements MapPartitionFunction<Object[], Params> {
        private static final long serialVersionUID = 536971312646228170L;
        private final String modelName;
        private final LinearModelType modelType;
        private final boolean hasInterceptItem;
        private final String vectorColName;
        private final String labelName;

        public CreateMeta(String str, LinearModelType linearModelType, Params params) {
            this.modelName = str;
            this.modelType = linearModelType;
            this.hasInterceptItem = ((Boolean) params.get(LinearTrainParams.WITH_INTERCEPT)).booleanValue();
            this.vectorColName = (String) params.get(LinearTrainParams.VECTOR_COL);
            this.labelName = (String) params.get(LinearTrainParams.LABEL_COL);
        }

        public void mapPartition(Iterable<Object[]> iterable, Collector<Params> collector) throws Exception {
            Object[] next = iterable.iterator().next();
            Params params = new Params();
            params.set((ParamInfo<ParamInfo<String>>) ModelParamName.MODEL_NAME, (ParamInfo<String>) this.modelName);
            params.set((ParamInfo<ParamInfo<LinearModelType>>) ModelParamName.LINEAR_MODEL_TYPE, (ParamInfo<LinearModelType>) this.modelType);
            if (LinearModelType.LinearReg != this.modelType && LinearModelType.SVR != this.modelType && LinearModelType.AFT != this.modelType) {
                params.set((ParamInfo<ParamInfo<Object[]>>) ModelParamName.LABEL_VALUES, (ParamInfo<Object[]>) next);
            }
            params.set((ParamInfo<ParamInfo<Boolean>>) ModelParamName.HAS_INTERCEPT_ITEM, (ParamInfo<Boolean>) Boolean.valueOf(this.hasInterceptItem));
            params.set((ParamInfo<ParamInfo<String>>) ModelParamName.VECTOR_COL_NAME, (ParamInfo<String>) this.vectorColName);
            params.set((ParamInfo<ParamInfo<String>>) LinearTrainParams.LABEL_COL, (ParamInfo<String>) this.labelName);
            collector.collect(params);
        }
    }

    /**
     * Converts raw input rows into {@code (weight, label, featureVector)} training samples and, after all samples of
     * a partition, emits one statistics record for that partition.
     *
     * <p>Decompiler note: this class was originally private; its access modifier was widened during decompilation.
     *
     * <p>Sample records: {@code f0} is the row weight (value of column {@code weightIdx}, or 1.0 when
     * {@code weightIdx == -1}); {@code f1} is the raw label; {@code f2} is the feature vector, prefixed with 1.0 when
     * {@code hasIntercept} is set.
     *
     * <p>Statistics record: {@code f0} is -1.0 if any sparse vector was seen, otherwise -2.0 (dense vectors or
     * feature columns); {@code f1} is {@code Tuple2(featureSize, distinctLabelValues)}; {@code f2} is a statistics
     * vector whose last slot is overwritten with the partition's row count. With a vector column and an empty
     * partition, no statistics record is emitted.
     *
     * <p>Layout of {@code f2} when {@code calcMeanVar} is set:
     * <ul>
     *   <li>feature columns / dense vectors: length {@code 3 * size + 1} = [sums | sums of squares | running max |
     *       count];</li>
     *   <li>sparse vectors only: length {@code size + 1} = [per-index max |value| | count];</li>
     *   <li>mixed sparse and dense: length {@code size + 1}, the sparse max |value| combined with the dense max
     *       section.</li>
     * </ul>
     * Without {@code calcMeanVar}, {@code f2} has length 1 and holds only the row count.
     *
     * <p>{@link #close()} fails the task if any null feature value or null label was encountered.
     */
    public static class Transform extends RichMapPartitionFunction<Row, Tuple3<Double, Object, Vector>> {
        private static final long serialVersionUID = 4360321564414289067L;
        // Regression: labels are not collected; a constant placeholder value is recorded instead.
        private final boolean isRegProc;
        // Index of the weight column, or -1 to use weight 1.0.
        private final int weightIdx;
        // Index of the vector column; used only when featureIndices is null.
        private final int vecIdx;
        private final int labelIdx;
        // Indices of the numeric feature columns, or null when a vector column is used.
        private final int[] featureIndices;
        private final boolean hasIntercept;
        private final boolean calcMeanVar;
        private boolean hasSparseVector = false;
        private boolean hasDenseVector = false;
        // Set when a feature column value is null; reported in close().
        private boolean hasNull = false;
        // Set when a label is null; reported in close().
        private boolean hasLabelNull = false;
        // Sparse path: feature index -> one-element array holding the max |value| seen for that index.
        private final Map<Integer, double[]> meanVarMap = new HashMap();

        /**
         * @param z  isRegProc
         * @param i  weightIdx (-1 for unit weights)
         * @param i2 vecIdx
         * @param iArr featureIndices (null when a vector column is used)
         * @param i3 labelIdx
         * @param z2 hasIntercept
         * @param z3 calcMeanVar
         */
        public Transform(boolean z, int i, int i2, int[] iArr, int i3, boolean z2, boolean z3) {
            this.isRegProc = z;
            this.weightIdx = i;
            this.vecIdx = i2;
            this.featureIndices = iArr;
            this.labelIdx = i3;
            this.hasIntercept = z2;
            this.calcMeanVar = z3;
        }

        /** Throws after the partition has been processed if any null feature value or null label was seen. */
        public void close() throws Exception {
            super.close();
            if (this.hasNull) {
                throw new AkIllegalDataException("The input data has null values, please check it!");
            }
            if (this.hasLabelNull) {
                throw new AkIllegalDataException("The input labels has null values, please check it!");
            }
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Double, Object, Vector>> collector) throws Exception {
            // Statistics vector emitted at the end of the partition.
            DenseVector denseVector;
            // Distinct label values seen in this partition.
            HashSet hashSet = new HashSet();
            // Feature dimension reported in the statistics record; -1 until known.
            int i = -1;
            // Number of rows in this partition.
            double d = 0.0d;
            // Dense-vector accumulator for the vector-column path with calcMeanVar: [sums | squares | max | count].
            DenseVector denseVector2 = null;
            if (this.featureIndices != null) {
                i = this.hasIntercept ? this.featureIndices.length + 1 : this.featureIndices.length;
                denseVector = this.calcMeanVar ? new DenseVector((3 * i) + 1) : new DenseVector(1);
            } else {
                // With a vector column and calcMeanVar, the statistics vector is built after the loop.
                denseVector = this.calcMeanVar ? null : new DenseVector(1);
            }
            for (Row row : iterable) {
                d += 1.0d;
                // NOTE(review): a null weight value throws NullPointerException here.
                Double valueOf = Double.valueOf(this.weightIdx != -1 ? ((Number) row.getField(this.weightIdx)).doubleValue() : 1.0d);
                Object field = row.getField(this.labelIdx);
                if (null == field) {
                    this.hasLabelNull = true;
                }
                if (this.isRegProc) {
                    // NOTE(review): a tree-module constant is used as the regression placeholder label; presumably
                    // its value is 0.0 — confirm.
                    hashSet.add(Double.valueOf(Criteria.INVALID_GAIN));
                } else {
                    hashSet.add(field);
                }
                if (this.featureIndices == null) {
                    Vector vector = VectorUtil.getVector(row.getField(this.vecIdx));
                    AkPreconditions.checkState(vector != null, "Vector for linear model train is null, please check your input data.");
                    if (vector instanceof SparseVector) {
                        this.hasSparseVector = true;
                        if (this.hasIntercept) {
                            Vector prefix = vector.prefix(1.0d);
                            int[] indices = ((SparseVector) prefix).getIndices();
                            double[] values = ((SparseVector) prefix).getValues();
                            for (int i2 = 0; i2 < indices.length; i2++) {
                                // Dimension = max(declared size, highest index + 1, previous dimension).
                                i = Math.max(prefix.size(), Math.max(i, indices[i2] + 1));
                                if (this.calcMeanVar) {
                                    // Track the max |value| per feature index.
                                    if (this.meanVarMap.containsKey(Integer.valueOf(indices[i2]))) {
                                        double[] dArr = this.meanVarMap.get(Integer.valueOf(indices[i2]));
                                        dArr[0] = Math.max(dArr[0], Math.abs(values[i2]));
                                    } else {
                                        this.meanVarMap.put(Integer.valueOf(indices[i2]), new double[]{Math.abs(values[i2])});
                                    }
                                }
                            }
                            collector.collect(Tuple3.of(valueOf, field, prefix));
                        } else {
                            int[] indices2 = ((SparseVector) vector).getIndices();
                            double[] values2 = ((SparseVector) vector).getValues();
                            for (int i3 = 0; i3 < indices2.length; i3++) {
                                i = Math.max(vector.size(), Math.max(i, indices2[i3] + 1));
                                if (this.calcMeanVar) {
                                    if (this.meanVarMap.containsKey(Integer.valueOf(indices2[i3]))) {
                                        double[] dArr2 = this.meanVarMap.get(Integer.valueOf(indices2[i3]));
                                        dArr2[0] = Math.max(dArr2[0], Math.abs(values2[i3]));
                                    } else {
                                        this.meanVarMap.put(Integer.valueOf(indices2[i3]), new double[]{Math.abs(values2[i3])});
                                    }
                                }
                            }
                            collector.collect(Tuple3.of(valueOf, field, vector));
                        }
                    } else {
                        this.hasDenseVector = true;
                        // NOTE(review): the dense branches OVERWRITE `i` with this row's length instead of taking a
                        // max. After sparse rows with larger indices, or with dense rows of differing lengths, `i` no
                        // longer matches the layout of denseVector2 (sized from the first dense row). That can misplace
                        // statistics or throw index-out-of-bounds, here or in the mixed merge below.
                        // NOTE(review): the "max" section compares raw values (not abs) against an initial 0, while the
                        // sparse path tracks max |value| — confirm this is intended.
                        if (this.hasIntercept) {
                            Vector prefix2 = vector.prefix(1.0d);
                            i = ((DenseVector) prefix2).getData().length;
                            if (this.calcMeanVar) {
                                if (denseVector2 == null) {
                                    denseVector2 = new DenseVector((3 * i) + 1);
                                }
                                for (int i4 = 0; i4 < i; i4++) {
                                    double d2 = prefix2.get(i4);
                                    denseVector2.add(i4, d2);
                                    denseVector2.add(i + i4, d2 * d2);
                                    denseVector2.set((2 * i) + i4, Math.max(denseVector2.get((2 * i) + i4), d2));
                                }
                                denseVector2.add(3 * i, 1.0d);
                            }
                            collector.collect(Tuple3.of(valueOf, field, prefix2));
                        } else {
                            i = ((DenseVector) vector).getData().length;
                            if (this.calcMeanVar) {
                                if (denseVector2 == null) {
                                    denseVector2 = new DenseVector((3 * i) + 1);
                                }
                                for (int i5 = 0; i5 < i; i5++) {
                                    double d3 = vector.get(i5);
                                    denseVector2.add(i5, d3);
                                    denseVector2.add(i + i5, d3 * d3);
                                    denseVector2.set((2 * i) + i5, Math.max(denseVector2.get((2 * i) + i5), d3));
                                }
                                denseVector2.add(3 * i, 1.0d);
                            }
                            collector.collect(Tuple3.of(valueOf, field, vector));
                        }
                    }
                } else if (this.hasIntercept) {
                    // Feature-column path with intercept: slot 0 is the constant 1.0.
                    DenseVector denseVector3 = new DenseVector(this.featureIndices.length + 1);
                    denseVector3.set(0, 1.0d);
                    if (this.calcMeanVar) {
                        // Intercept contributes 1 to both its sum (slot 0) and its sum of squares (slot i).
                        denseVector.add(0, 1.0d);
                        denseVector.add(i, 1.0d);
                    }
                    for (int i6 = 1; i6 < this.featureIndices.length + 1; i6++) {
                        if (row.getField(this.featureIndices[i6 - 1]) == null) {
                            // Null feature: left as 0.0 here; close() fails the task afterwards.
                            this.hasNull = true;
                        } else {
                            double doubleValue = ((Number) row.getField(this.featureIndices[i6 - 1])).doubleValue();
                            denseVector3.set(i6, doubleValue);
                            if (this.calcMeanVar) {
                                denseVector.add(i6, doubleValue);
                                denseVector.add(i + i6, doubleValue * doubleValue);
                            }
                        }
                    }
                    // NOTE(review): the feature-column path never fills the [2*i, 3*i) max section.
                    if (this.calcMeanVar) {
                        denseVector.add(3 * i, 1.0d);
                    }
                    collector.collect(Tuple3.of(valueOf, field, denseVector3));
                } else {
                    DenseVector denseVector4 = new DenseVector(this.featureIndices.length);
                    for (int i7 = 0; i7 < this.featureIndices.length; i7++) {
                        if (row.getField(this.featureIndices[i7]) == null) {
                            this.hasNull = true;
                        } else {
                            double doubleValue2 = ((Number) row.getField(this.featureIndices[i7])).doubleValue();
                            denseVector4.set(i7, doubleValue2);
                            if (this.calcMeanVar) {
                                denseVector.add(i7, doubleValue2);
                                denseVector.add(i + i7, doubleValue2 * doubleValue2);
                            }
                        }
                    }
                    if (this.calcMeanVar) {
                        denseVector.add(3 * i, 1.0d);
                    }
                    collector.collect(Tuple3.of(valueOf, field, denseVector4));
                }
            }
            // Vector-column path with calcMeanVar: assemble the statistics vector from what was accumulated.
            if (denseVector == null) {
                if (this.hasSparseVector && !this.hasDenseVector) {
                    // Sparse only: per-index max |value|, plus a trailing count slot.
                    denseVector = new DenseVector(i + 1);
                    for (Integer num : this.meanVarMap.keySet()) {
                        denseVector.set(num.intValue(), this.meanVarMap.get(num)[0]);
                    }
                } else if (this.hasSparseVector) {
                    // Mixed sparse and dense: combine sparse max |value| with the dense max section.
                    denseVector = new DenseVector(i + 1);
                    for (Integer num2 : this.meanVarMap.keySet()) {
                        denseVector.set(num2.intValue(), this.meanVarMap.get(num2)[0]);
                    }
                    for (int i8 = 0; i8 < i; i8++) {
                        denseVector.set(i8, Math.max(denseVector.get(i8), Math.abs(denseVector2.get((2 * i) + i8))));
                    }
                } else {
                    // Dense only (or no rows at all, in which case this stays null and nothing is emitted below).
                    denseVector = denseVector2;
                }
            }
            // Emit the partition statistics record; the last slot always carries the row count.
            if (this.hasSparseVector) {
                denseVector.set(denseVector.size() - 1, d);
                collector.collect(Tuple3.of(Double.valueOf(-1.0d), Tuple2.of(Integer.valueOf(i), hashSet.toArray()), denseVector));
            } else if (this.hasDenseVector || this.featureIndices != null) {
                denseVector.set(denseVector.size() - 1, d);
                collector.collect(Tuple3.of(Double.valueOf(-2.0d), Tuple2.of(Integer.valueOf(i), hashSet.toArray()), denseVector));
            }
        }
    }

    /**
     * Creates a linear-model training operator.
     *
     * @param params    operator parameters
     * @param modelType the kind of linear model this operator trains
     * @param name      model name, later recorded in the model meta
     */
    public BaseLinearModelTrainBatchOp(Params params, LinearModelType modelType, String name) {
        super(params);
        this.linearModelType = modelType;
        this.modelName = name;
    }

    /**
     * Wires up the training job.
     *
     * <p>The first input is the training data. An optional second input is an initial model used for warm start.
     * The pipeline runs these steps in order:
     * <ol>
     *   <li>transform the rows into weighted samples;</li>
     *   <li>derive the feature statistics, label values and feature size;</li>
     *   <li>run the optimizer;</li>
     *   <li>build the model rows with parallelism 1;</li>
     *   <li>register the coefficient side tables.</li>
     * </ol>
     *
     * @param batchOperatorArr training data, optionally followed by an initial model
     * @return this operator
     * @throws AkIllegalArgumentException if both feature columns and a vector column are set
     */
    @Override
    public T linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in;
        BatchOperator<?> initModel = null;
        if (batchOperatorArr.length == 1) {
            in = checkAndGetFirst(batchOperatorArr);
        } else {
            in = batchOperatorArr[0];
            initModel = batchOperatorArr[1];
        }
        Params params = getParams();
        if (params.contains(HasFeatureCols.FEATURE_COLS) && params.contains(HasVectorCol.VECTOR_COL)) {
            throw new AkIllegalArgumentException("FeatureCols and vectorCol cannot be set at the same time.");
        }
        final boolean isRegProc = getIsRegProc(params, this.linearModelType, this.modelName);
        final boolean standardization = params.get(LinearTrainParams.STANDARDIZATION);
        TypeInformation<?> labelType = isRegProc
            ? Types.DOUBLE
            : in.getColTypes()[TableUtil.findColIndexWithAssertAndHint(
                in.getColNames(), params.get(LinearTrainParams.LABEL_COL))];

        DataSet<Tuple3<Double, Object, Vector>> trainData = transform(in, params, isRegProc, standardization);
        DataSet<Tuple3<DenseVector[], Object[], Integer[]>> utilInfo =
            getUtilInfo(trainData, standardization, isRegProc);

        // Feature statistics (mean/variance vectors), broadcast to the model builder and the initial-model loader.
        DataSet<DenseVector[]> meanVar = utilInfo.map(
            new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, DenseVector[]>() {
                private static final long serialVersionUID = 7127767376687624403L;

                @Override
                public DenseVector[] map(Tuple3<DenseVector[], Object[], Integer[]> value) {
                    return value.f0;
                }
            });
        // Feature dimension of the training data.
        DataSet<Integer> featSize = utilInfo.map(
            new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, Integer>() {
                private static final long serialVersionUID = 2773811388068064638L;

                @Override
                public Integer map(Tuple3<DenseVector[], Object[], Integer[]> value) {
                    return value.f2[0];
                }
            });
        // Distinct label values; classification requires exactly two.
        DataSet<Object[]> labelValues = utilInfo.flatMap(
            new FlatMapFunction<Tuple3<DenseVector[], Object[], Integer[]>, Object[]>() {
                private static final long serialVersionUID = 5375954526931728363L;

                @Override
                public void flatMap(Tuple3<DenseVector[], Object[], Integer[]> value, Collector<Object[]> out) {
                    if (!isRegProc) {
                        AkPreconditions.checkState(value.f1.length == 2,
                            "Labels count should be 2 in linear classification algo.");
                    }
                    out.collect(value.f1);
                }
            });

        DataSet<Tuple2<DenseVector, double[]>> coefs = optimize(
            params,
            featSize,
            preProcess(trainData, params, isRegProc, meanVar, labelValues, featSize),
            getInitialModel(initModel, featSize, meanVar, params, this.linearModelType),
            this.linearModelType,
            MLEnvironmentFactory.get(getMLEnvironmentId()));

        final String[] featureCols = params.get(LinearTrainParams.FEATURE_COLS);
        final boolean withIntercept = params.get(LinearTrainParams.WITH_INTERCEPT);
        DataSet<Row> model = coefs
            .mapPartition(new BuildModelFromCoefs(labelType, featureCols, standardization, withIntercept,
                getFeatureTypes(in, featureCols)))
            .withBroadcastSet(
                labelValues.mapPartition(new CreateMeta(this.modelName, this.linearModelType, params))
                    .setParallelism(1),
                META)
            .withBroadcastSet(meanVar, MEAN_VAR)
            .setParallelism(1);
        setOutput(model, new LinearModelDataConverter(labelType).getModelSchema());
        setSideOutputTables(getSideTablesOfCoefficient(coefs.project(new int[]{1}), model, trainData, featSize,
            featureCols, withIntercept, getMLEnvironmentId().longValue()));

        // T is this operator's own concrete type (T extends BaseLinearModelTrainBatchOp<T>).
        @SuppressWarnings("unchecked")
        T self = (T) this;
        return self;
    }

    /**
     * Loads an optional initial (warm-start) model and maps its coefficients into the training coefficient space.
     *
     * <p>The model rows are gathered in a single group and loaded with {@link LinearModelDataConverter}. They are
     * then checked against the training settings. Each coefficient i (excluding the intercept slot when present) is
     * multiplied by {@code meanVar[1][i]}. With an intercept, the intercept is shifted by
     * {@code sum_i coef[i] * meanVar[0][i]}, using the coefficients before scaling.
     *
     * @param batchOperator   operator producing the initial model rows, or null for no warm start
     * @param dataSet         single-element data set holding the training feature size (broadcast as "featSize")
     * @param dataSet2        single-element data set holding the {@code DenseVector[]} statistics (broadcast as
     *                        MEAN_VAR)
     * @param params          training parameters; WITH_INTERCEPT must match the initial model
     * @param linearModelType model type the initial model must have
     * @return the adjusted coefficient vector as a data set, or {@code null} when {@code batchOperator} is null.
     *     At job run time the reduce throws {@link AkIllegalArgumentException} on an intercept or model-type
     *     mismatch, and {@link AkIllegalDataException} on a vector-size mismatch.
     */
    public static DataSet<DenseVector> getInitialModel(BatchOperator<?> batchOperator, DataSet<Integer> dataSet,
                                                       DataSet<DenseVector[]> dataSet2, final Params params,
                                                       final LinearModelType linearModelType) {
        if (batchOperator == null) {
            return null;
        }
        return batchOperator.getDataSet().reduceGroup(new RichGroupReduceFunction<Row, DenseVector>() {
            @Override
            public void reduce(Iterable<Row> rows, Collector<DenseVector> out) {
                int featSize = (Integer) getRuntimeContext().getBroadcastVariable("featSize").get(0);
                DenseVector[] meanVar =
                    (DenseVector[]) getRuntimeContext().getBroadcastVariable(BaseLinearModelTrainBatchOp.MEAN_VAR).get(0);

                List<Row> modelRows = new ArrayList<>();
                for (Row row : rows) {
                    modelRows.add(row);
                }
                LinearModelData initModel = new LinearModelDataConverter().load(modelRows);

                // The anonymous class captures the enclosing method's `params` argument.
                if (initModel.hasInterceptItem != params.get(HasWithIntercept.WITH_INTERCEPT)) {
                    throw new AkIllegalArgumentException("Initial linear model is not compatible with parameter setting.InterceptItem parameter setting error.");
                }
                if (initModel.linearModelType != linearModelType) {
                    throw new AkIllegalArgumentException("Initial linear model is not compatible with parameter setting.linearModelType setting error.");
                }
                if (initModel.vectorSize != featSize) {
                    throw new AkIllegalDataException("Initial linear model is not compatible with training data.  vector size not equal, vector size in init model is : " + initModel.vectorSize + " and vector size of train data is : " + featSize);
                }

                DenseVector coef = initModel.coefVector;
                // Presumably meanVar[0] holds feature means and meanVar[1] per-feature scales — TODO confirm.
                DenseVector means = meanVar[0];
                DenseVector scales = meanVar[1];
                int size = means.size();
                if (initModel.hasInterceptItem) {
                    // Slot 0 is the intercept: fold the mean contribution of every feature into it.
                    double interceptShift = 0.0;
                    for (int i = 1; i < size; i++) {
                        interceptShift += coef.get(i) * means.get(i);
                        coef.set(i, coef.get(i) * scales.get(i));
                    }
                    coef.set(0, coef.get(0) + interceptShift);
                } else {
                    for (int i = 0; i < size; i++) {
                        coef.set(i, coef.get(i) * scales.get(i));
                    }
                }
                out.collect(coef);
            }
        }).withBroadcastSet(dataSet, "featSize").withBroadcastSet(dataSet2, MEAN_VAR);
    }

    public static DataSet<Tuple3<DenseVector[], Object[], Integer[]>> getUtilInfo(DataSet<Tuple3<Double, Object, Vector>> dataSet, final boolean z, final boolean z2) {
        return dataSet.filter(new FilterFunction<Tuple3<Double, Object, Vector>>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.6
            private static final long serialVersionUID = 4129133776653527498L;

            public boolean filter(Tuple3<Double, Object, Vector> tuple3) {
                return ((Double) tuple3.f0).doubleValue() < Criteria.INVALID_GAIN;
            }
        }).reduceGroup(new GroupReduceFunction<Tuple3<Double, Object, Vector>, Tuple3<DenseVector[], Object[], Integer[]>>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.5
            private static final long serialVersionUID = -4819473589070441623L;
            static final /* synthetic */ boolean $assertionsDisabled;

            public void reduce(Iterable<Tuple3<Double, Object, Vector>> iterable, Collector<Tuple3<DenseVector[], Object[], Integer[]>> collector) {
                int i = -1;
                int i2 = -1;
                HashSet hashSet = new HashSet();
                DenseVector denseVector = null;
                DenseVector denseVector2 = null;
                boolean z3 = false;
                boolean z4 = false;
                boolean z5 = false;
                boolean z6 = false;
                ArrayList<Tuple3> arrayList = new ArrayList();
                ArrayList<Tuple3> arrayList2 = new ArrayList();
                int i3 = 0;
                for (Tuple3<Double, Object, Vector> tuple3 : iterable) {
                    if (((Double) tuple3.f0).doubleValue() == -1.0d) {
                        arrayList2.add(tuple3);
                        z3 = true;
                    } else if (((Double) tuple3.f0).doubleValue() == -2.0d) {
                        arrayList.add(tuple3);
                        z4 = true;
                    }
                    i3 = (int) (i3 + ((Vector) tuple3.f2).get(((Vector) tuple3.f2).size() - 1));
                }
                if (z3) {
                    for (Tuple3 tuple32 : arrayList2) {
                        Tuple2 tuple2 = (Tuple2) tuple32.f1;
                        Collections.addAll(hashSet, (Object[]) tuple2.f1);
                        if (denseVector == null) {
                            denseVector = (DenseVector) tuple32.f2;
                            z5 = denseVector != null && denseVector.size() > 1;
                            i = ((Integer) tuple2.f0).intValue();
                        } else if (((Integer) tuple2.f0).intValue() == i) {
                            if (z5) {
                                for (int i4 = 0; i4 < denseVector.size(); i4++) {
                                    denseVector.set(i4, Math.max(denseVector.get(i4), Math.abs(((Vector) tuple32.f2).get(i4))));
                                }
                            }
                        } else if (z5) {
                            if (((Integer) tuple2.f0).intValue() < i) {
                                for (int i5 = 0; i5 < ((Double) tuple32.f0).doubleValue(); i5++) {
                                    denseVector.set(i5, Math.max(denseVector.get(i5), Math.abs(((Vector) tuple32.f2).get(i5))));
                                }
                            } else {
                                for (int i6 = 0; i6 < i; i6++) {
                                    ((Vector) tuple32.f2).set(i6, Math.max(Math.abs(((Vector) tuple32.f2).get(i6)), denseVector.get(i6)));
                                }
                                denseVector = (DenseVector) tuple32.f2;
                                i = ((Integer) tuple2.f0).intValue();
                            }
                        }
                    }
                }
                if (z4) {
                    for (Tuple3 tuple33 : arrayList) {
                        Tuple2 tuple22 = (Tuple2) tuple33.f1;
                        hashSet.addAll(Arrays.asList((Object[]) tuple22.f1));
                        if (denseVector2 == null) {
                            denseVector2 = (DenseVector) tuple33.f2;
                            z6 = denseVector2 != null && denseVector2.size() > 1;
                            i2 = ((Integer) tuple22.f0).intValue();
                        } else if (((Integer) tuple22.f0).intValue() == i2) {
                            if (z6) {
                                for (int i7 = 0; i7 < i2; i7++) {
                                    denseVector2.set(i7, denseVector2.get(i7) + ((Vector) tuple33.f2).get(i7));
                                    denseVector2.set(i2 + i7, denseVector2.get(i2 + i7) + ((Vector) tuple33.f2).get(i2 + i7));
                                    denseVector2.set((2 * i2) + i7, Math.max(denseVector2.get((2 * i2) + i7), Math.abs(((Vector) tuple33.f2).get((2 * i2) + i7))));
                                }
                                denseVector2.set(3 * i2, denseVector2.get(3 * i2) + ((Vector) tuple33.f2).get(3 * i2));
                            }
                        } else if (((Integer) tuple22.f0).intValue() >= i2) {
                            if (z6) {
                                for (int i8 = 0; i8 < i2; i8++) {
                                    ((Vector) tuple33.f2).set(i8, denseVector2.get(i8) + ((Vector) tuple33.f2).get(i8));
                                    ((Vector) tuple33.f2).set(i2 + i8, denseVector2.get(i2 + i8) + ((Vector) tuple33.f2).get(i2 + i8));
                                    ((Vector) tuple33.f2).set((2 * i2) + i8, Math.max(denseVector2.get((2 * i2) + i8), Math.abs(((Vector) tuple33.f2).get((2 * i2) + i8))));
                                }
                                ((Vector) tuple33.f2).set(3 * i2, denseVector2.get(3 * i2) + ((Vector) tuple33.f2).get(3 * i2));
                                denseVector2 = (DenseVector) tuple33.f2;
                            }
                            i2 = ((Integer) tuple22.f0).intValue();
                        } else if (z6) {
                            for (int i9 = 0; i9 < ((Integer) tuple22.f0).intValue(); i9++) {
                                denseVector2.set(i9, denseVector2.get(i9) + ((Vector) tuple33.f2).get(i9));
                                denseVector2.set(i2 + i9, denseVector2.get(i2 + i9) + ((Vector) tuple33.f2).get(i2 + i9));
                                denseVector2.set((2 * i2) + i9, Math.max(denseVector2.get((2 * i2) + i9), Math.abs(((Vector) tuple33.f2).get((2 * i2) + i9))));
                            }
                            denseVector2.set(3 * i2, denseVector2.get(3 * i2) + ((Vector) tuple33.f2).get(3 * i2));
                        }
                    }
                }
                boolean z7 = z6 || z5;
                if (z3 && z4) {
                    if (z7) {
                        if (!$assertionsDisabled && denseVector == null) {
                            throw new AssertionError();
                        }
                        if (!$assertionsDisabled && denseVector2 == null) {
                            throw new AssertionError();
                        }
                        if (denseVector.size() >= denseVector2.size() / 3) {
                            for (int i10 = 0; i10 < i; i10++) {
                                denseVector.set(i10, Math.max(denseVector.get(i10), Math.abs(denseVector2.get((2 * i) + i10))));
                            }
                        } else {
                            DenseVector denseVector3 = new DenseVector(i2);
                            for (int i11 = 0; i11 < denseVector.size(); i11++) {
                                denseVector3.set(i11, denseVector.get(i11));
                            }
                            for (int i12 = 0; i12 < i2; i12++) {
                                denseVector3.set(i12, Math.max(denseVector3.get(i12), Math.abs(denseVector2.get((2 * i2) + i12))));
                            }
                            denseVector = denseVector3;
                        }
                    }
                } else if (z4) {
                    denseVector = denseVector2;
                }
                int max = Math.max(i, i2);
                DenseVector[] denseVectorArr = new DenseVector[2];
                Object[] array = z2 ? hashSet.toArray() : BaseLinearModelTrainBatchOp.orderLabels(hashSet);
                denseVectorArr[0] = z7 ? new DenseVector(max) : new DenseVector(0);
                denseVectorArr[1] = z7 ? new DenseVector(max) : new DenseVector(0);
                if (z7) {
                    if (z3) {
                        denseVectorArr[1] = denseVector;
                        BaseLinearModelTrainBatchOp.modifyMeanVar(z, denseVectorArr);
                    } else {
                        for (int i13 = 0; i13 < max; i13++) {
                            denseVectorArr[0].set(i13, denseVector.get(i13) / denseVector.get(3 * max));
                            denseVectorArr[1].set(i13, denseVector.get(max + i13) - ((denseVector.get(3 * max) * denseVectorArr[0].get(i13)) * denseVectorArr[0].get(i13)));
                        }
                        for (int i14 = 0; i14 < max; i14++) {
                            denseVectorArr[1].set(i14, Math.max(Criteria.INVALID_GAIN, denseVectorArr[1].get(i14)));
                            denseVectorArr[1].set(i14, Math.sqrt(denseVectorArr[1].get(i14) / (denseVector.get(3 * max) - 1.0d)));
                        }
                        BaseLinearModelTrainBatchOp.modifyMeanVar(z, denseVectorArr);
                    }
                }
                collector.collect(Tuple3.of(denseVectorArr, array, new Integer[]{Integer.valueOf(max), Integer.valueOf(i3)}));
            }

            static {
                $assertionsDisabled = !BaseLinearModelTrainBatchOp.class.desiredAssertionStatus();
            }
        });
    }

    /**
     * Builds the three side tables that accompany a trained linear model: a model-info table, a
     * per-feature importance table and a per-feature weight table.
     *
     * <p>The importance of feature {@code i} is {@code |coef_i| * s_i}. Here {@code s_i} is the
     * sample standard deviation of feature {@code i}, computed over the rows of {@code dataSet3}
     * whose {@code f0} is {@code >= Criteria.INVALID_GAIN}. {@code s_i} is taken as 0 when exactly
     * one such row exists.
     *
     * @param dataSet  single-element data set; its {@code double[]} is broadcast as {@code "cinfo"}
     *                 and emitted verbatim (as JSON) in info row 4
     * @param dataSet2 model rows; loaded into a {@code LinearModelData} by
     *                 {@code LinearModelDataConverter} at parallelism 1 and broadcast as
     *                 {@code OptimVariable.model}
     * @param dataSet3 training samples; {@code f2} holds the feature vector
     * @param dataSet4 vector size, broadcast as {@code KMeansTrainBatchOp.VECTOR_SIZE}
     * @param strArr   feature column names, or {@code null} to name features {@code "0"},
     *                 {@code "1"}, ...
     * @param z        {@code true} when vector/coefficient index 0 is the intercept; that slot is
     *                 then excluded from the per-feature statistics
     * @param j        passed as the first argument of {@code DataSetConversionUtil.toTable}
     *                 (presumably the ML environment id; confirm)
     * @return {@code [info, importance, weight]} tables:
     *     <ul>
     *       <li>{@code (id LONG, info STRING)}: rows 0..4 hold the model meta-info JSON, then the
     *           feature names, coefficients, importances and {@code cinfo}, each serialized as JSON.
     *           With 10000 or more features, rows 1..3 keep only the first and last 5000 entries
     *           after sorting by {@code BaseLinearModelTrainBatchOp.compare}.</li>
     *       <li>{@code (col_name STRING, importance DOUBLE)}: one row per feature.</li>
     *       <li>{@code (col_name STRING, ConstraintVariable.weight DOUBLE)}: one row per
     *           coefficient. An {@code "_intercept_"} row comes first when there are more
     *           coefficients than names.</li>
     *     </ul>
     */
    public static Table[] getSideTablesOfCoefficient(DataSet<Tuple1<double[]>> dataSet, DataSet<Row> dataSet2, DataSet<Tuple3<Double, Object, Vector>> dataSet3, DataSet<Integer> dataSet4, final String[] strArr, final boolean z, long j) {
        // Pipeline:
        //   1. mapPartition: per partition, (usable row count, per-feature sum, per-feature sum of squares).
        //   2. reduce: merge the partial results.
        //   3. flatMap (parallelism 1): combine the totals with the broadcast model and cinfo into one
        //      Tuple5(metaInfoJson, featureNames, coefficients, importances, cinfo).
        // The three downstream mapPartitions below each turn that single Tuple5 into one output table.
        SingleInputUdfOperator withBroadcastSet = dataSet3.mapPartition(new RichMapPartitionFunction<Tuple3<Double, Object, Vector>, Tuple3<Integer, double[], double[]>>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.10
            private static final long serialVersionUID = 8785824618242390100L;
            private int vectorSize;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                this.vectorSize = ((Integer) getRuntimeContext().getBroadcastVariable(KMeansTrainBatchOp.VECTOR_SIZE).get(0)).intValue();
                // With an intercept, index 0 of each vector is the intercept slot; size the accumulators for features only.
                if (z) {
                    this.vectorSize--;
                }
            }

            public void mapPartition(Iterable<Tuple3<Double, Object, Vector>> iterable, Collector<Tuple3<Integer, double[], double[]>> collector) {
                // i: usable row count; dArr: per-feature sums; dArr2: per-feature sums of squares.
                int i = 0;
                double[] dArr = new double[this.vectorSize];
                double[] dArr2 = new double[this.vectorSize];
                if (strArr == null) {
                    // Without feature column names, samples may be sparse or dense vectors.
                    for (Tuple3<Double, Object, Vector> tuple3 : iterable) {
                        if (((Double) tuple3.f0).doubleValue() >= Criteria.INVALID_GAIN) {
                            if (tuple3.f2 instanceof SparseVector) {
                                SparseVector sparseVector = (SparseVector) tuple3.f2;
                                // NOTE(review): resizes the input vector in place. With an intercept, vectorSize
                                // has already been reduced by one, so confirm the size semantics SparseVector expects here.
                                sparseVector.setSize(this.vectorSize);
                                double[] values = sparseVector.getValues();
                                int[] indices = sparseVector.getIndices();
                                for (int i2 = 0; i2 < values.length; i2++) {
                                    if (!z) {
                                        int i3 = indices[i2];
                                        dArr[i3] = dArr[i3] + values[i2];
                                        int i4 = indices[i2];
                                        dArr2[i4] = dArr2[i4] + (values[i2] * values[i2]);
                                    } else if (indices[i2] > 0) {
                                        // Skip the intercept at index 0 and shift the remaining features down by one.
                                        int i5 = indices[i2] - 1;
                                        dArr[i5] = dArr[i5] + values[i2];
                                        int i6 = indices[i2] - 1;
                                        dArr2[i6] = dArr2[i6] + (values[i2] * values[i2]);
                                    }
                                }
                                i++;
                            } else {
                                for (int i7 = 0; i7 < this.vectorSize; i7++) {
                                    double d = ((Vector) tuple3.f2).get(i7 + (z ? 1 : 0));
                                    int i8 = i7;
                                    dArr[i8] = dArr[i8] + d;
                                    int i9 = i7;
                                    dArr2[i9] = dArr2[i9] + (d * d);
                                }
                                i++;
                            }
                        }
                    }
                } else {
                    // With feature column names, every sample is read by dense index (offset past the intercept when z).
                    for (Tuple3<Double, Object, Vector> tuple32 : iterable) {
                        if (((Double) tuple32.f0).doubleValue() >= Criteria.INVALID_GAIN) {
                            for (int i10 = 0; i10 < this.vectorSize; i10++) {
                                double d2 = ((Vector) tuple32.f2).get(i10 + (z ? 1 : 0));
                                int i11 = i10;
                                dArr[i11] = dArr[i11] + d2;
                                int i12 = i10;
                                dArr2[i12] = dArr2[i12] + (d2 * d2);
                            }
                            i++;
                        }
                    }
                }
                collector.collect(Tuple3.of(Integer.valueOf(i), dArr, dArr2));
            }
        }).withBroadcastSet(dataSet4, KMeansTrainBatchOp.VECTOR_SIZE).reduce(new ReduceFunction<Tuple3<Integer, double[], double[]>>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.9
            private static final long serialVersionUID = 7062783877162095989L;

            public Tuple3<Integer, double[], double[]> reduce(Tuple3<Integer, double[], double[]> tuple3, Tuple3<Integer, double[], double[]> tuple32) {
                // Add counts and per-feature accumulators element-wise, reusing the second tuple as the result.
                tuple32.f0 = Integer.valueOf(((Integer) tuple3.f0).intValue() + ((Integer) tuple32.f0).intValue());
                for (int i = 0; i < ((double[]) tuple3.f1).length; i++) {
                    ((double[]) tuple32.f1)[i] = ((double[]) tuple3.f1)[i] + ((double[]) tuple32.f1)[i];
                    ((double[]) tuple32.f2)[i] = ((double[]) tuple3.f2)[i] + ((double[]) tuple32.f2)[i];
                }
                return tuple32;
            }
        }).flatMap(new RichFlatMapFunction<Tuple3<Integer, double[], double[]>, Tuple5<String, String[], double[], double[], double[]>>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.8
            private static final long serialVersionUID = 7815111101106759520L;
            private DenseVector coefVec;
            private Tuple2<DenseVector, double[]> model;
            private double[] cinfo;
            private Params metaInfo;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                // cinfo is forwarded unchanged; the model supplies the coefficients and meta info.
                this.cinfo = (double[]) ((Tuple1) getRuntimeContext().getBroadcastVariable("cinfo").get(0)).f0;
                LinearModelData linearModelData = (LinearModelData) getRuntimeContext().getBroadcastVariable(OptimVariable.model).get(0);
                this.coefVec = linearModelData.coefVector;
                this.metaInfo = linearModelData.getMetaInfo();
            }

            public void flatMap(Tuple3<Integer, double[], double[]> tuple3, Collector<Tuple5<String, String[], double[], double[], double[]>> collector) {
                // Default feature names are "0", "1", ... (one per non-intercept coefficient).
                String[] strArr2;
                if (strArr == null) {
                    strArr2 = new String[this.coefVec.size() - (z ? 1 : 0)];
                    for (int i = 0; i < strArr2.length; i++) {
                        strArr2[i] = String.valueOf(i);
                    }
                } else {
                    strArr2 = strArr;
                }
                // importance_i = |coef_i| * sqrt(max(INVALID_GAIN, sumSq_i - n * mean_i^2) / (n - 1)), or 0 when n == 1.
                // NOTE(review): when n == 0 (no usable rows) this yields NaN; confirm callers tolerate that.
                double[] dArr = z ? new double[this.coefVec.size() - 1] : new double[this.coefVec.size()];
                for (int i2 = 0; i2 < ((double[]) tuple3.f1).length; i2++) {
                    double intValue = ((double[]) tuple3.f1)[i2] / ((Integer) tuple3.f0).intValue();
                    dArr[i2] = Math.abs(this.coefVec.get(i2 + (z ? 1 : 0)) * (((Integer) tuple3.f0).intValue() == 1 ? 0.0d : Math.sqrt(Math.max(Criteria.INVALID_GAIN, ((double[]) tuple3.f2)[i2] - ((((Integer) tuple3.f0).intValue() * intValue) * intValue)) / (((Integer) tuple3.f0).intValue() - 1))));
                }
                collector.collect(Tuple5.of(JsonConverter.toJson(this.metaInfo), strArr2, this.coefVec.getData(), dArr, this.cinfo));
            }

            public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                flatMap((Tuple3<Integer, double[], double[]>) obj, (Collector<Tuple5<String, String[], double[], double[], double[]>>) collector);
            }
        }).setParallelism(1).withBroadcastSet(dataSet2.mapPartition(new MapPartitionFunction<Row, LinearModelData>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.7
            private static final long serialVersionUID = 2063366042018382802L;

            public void mapPartition(Iterable<Row> iterable, Collector<LinearModelData> collector) {
                // Collect all model rows (parallelism 1) and deserialize them into a single LinearModelData.
                ArrayList arrayList = new ArrayList();
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    arrayList.add(it.next());
                }
                collector.collect(new LinearModelDataConverter().load((List<Row>) arrayList));
            }
        }).setParallelism(1), OptimVariable.model).withBroadcastSet(dataSet, "cinfo");
        MapPartitionOperator mapPartition = withBroadcastSet.mapPartition(new MapPartitionFunction<Tuple5<String, String[], double[], double[], double[]>, Row>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.11
            private static final long serialVersionUID = -3263497114974298286L;

            public void mapPartition(Iterable<Tuple5<String, String[], double[], double[], double[]>> iterable, Collector<Row> collector) {
                // Importance table: one (name, importance) row per feature, from the last Tuple5 seen.
                String[] strArr2 = null;
                double[] dArr = null;
                for (Tuple5<String, String[], double[], double[], double[]> tuple5 : iterable) {
                    strArr2 = (String[]) tuple5.f1;
                    dArr = (double[]) tuple5.f3;
                }
                for (int i = 0; i < ((String[]) Objects.requireNonNull(strArr2)).length; i++) {
                    collector.collect(Row.of(new Object[]{strArr2[i], Double.valueOf(dArr[i])}));
                }
            }
        });
        MapPartitionOperator mapPartition2 = withBroadcastSet.mapPartition(new MapPartitionFunction<Tuple5<String, String[], double[], double[], double[]>, Row>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.12
            private static final long serialVersionUID = -6164289179429722407L;
            static final /* synthetic */ boolean $assertionsDisabled;

            public void mapPartition(Iterable<Tuple5<String, String[], double[], double[], double[]>> iterable, Collector<Row> collector) {
                // Weight table: one (name, coefficient) row per coefficient.
                String[] strArr2 = null;
                double[] dArr = null;
                for (Tuple5<String, String[], double[], double[], double[]> tuple5 : iterable) {
                    strArr2 = (String[]) tuple5.f1;
                    dArr = (double[]) tuple5.f2;
                }
                if (!$assertionsDisabled && dArr == null) {
                    throw new AssertionError();
                }
                // Equal lengths: no intercept coefficient, so names and coefficients pair up directly.
                if (dArr.length == strArr2.length) {
                    for (int i = 0; i < strArr2.length; i++) {
                        collector.collect(Row.of(new Object[]{strArr2[i], Double.valueOf(dArr[i])}));
                    }
                    return;
                }
                // Otherwise coefficient 0 is emitted as the intercept and the remaining ones are offset by one.
                collector.collect(Row.of(new Object[]{"_intercept_", Double.valueOf(dArr[0])}));
                for (int i2 = 0; i2 < strArr2.length; i2++) {
                    collector.collect(Row.of(new Object[]{strArr2[i2], Double.valueOf(dArr[i2 + 1])}));
                }
            }

            static {
                $assertionsDisabled = !BaseLinearModelTrainBatchOp.class.desiredAssertionStatus();
            }
        });
        return new Table[]{DataSetConversionUtil.toTable(Long.valueOf(j), (DataSet<Row>) withBroadcastSet.mapPartition(new MapPartitionFunction<Tuple5<String, String[], double[], double[], double[]>, Row>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.13
            private static final long serialVersionUID = -6164289179429722407L;
            // NOTE(review): declared but the literal 10000 (and its half, 5000) is used below instead.
            private static final int NUM_COLLECT_THRESHOLD = 10000;

            public void mapPartition(Iterable<Tuple5<String, String[], double[], double[], double[]>> iterable, Collector<Row> collector) {
                // Info table rows: 0 meta info, 1 names, 2 coefficients, 3 importances, 4 cinfo.
                for (Tuple5<String, String[], double[], double[], double[]> tuple5 : iterable) {
                    if (((String[]) tuple5.f1).length < 10000) {
                        collector.collect(Row.of(new Object[]{0L, tuple5.f0}));
                        collector.collect(Row.of(new Object[]{1L, JsonConverter.toJson(tuple5.f1)}));
                        collector.collect(Row.of(new Object[]{2L, JsonConverter.toJson(tuple5.f2)}));
                        collector.collect(Row.of(new Object[]{3L, JsonConverter.toJson(tuple5.f3)}));
                        collector.collect(Row.of(new Object[]{4L, JsonConverter.toJson(tuple5.f4)}));
                    } else {
                        // Too many features: sort (name, coefficient, importance) triples and keep the first and last 5000.
                        // NOTE(review): coefficients are read at offset i here (skipping the intercept when z),
                        // whereas the untruncated branch above serializes tuple5.f2 including the intercept.
                        ArrayList arrayList = new ArrayList(((String[]) tuple5.f1).length);
                        int i = z ? 1 : 0;
                        for (int i2 = 0; i2 < ((String[]) tuple5.f1).length; i2++) {
                            arrayList.add(Tuple3.of(((String[]) tuple5.f1)[i2], Double.valueOf(((double[]) tuple5.f2)[i2 + i]), Double.valueOf(((double[]) tuple5.f3)[i2])));
                        }
                        arrayList.sort(BaseLinearModelTrainBatchOp.compare);
                        String[] strArr2 = new String[10000];
                        double[] dArr = new double[10000];
                        double[] dArr2 = new double[10000];
                        // Fill the head from the front of the sorted list and the tail from its back.
                        for (int i3 = 0; i3 < 5000; i3++) {
                            strArr2[i3] = (String) ((Tuple3) arrayList.get(i3)).f0;
                            dArr[i3] = ((Double) ((Tuple3) arrayList.get(i3)).f1).doubleValue();
                            dArr2[i3] = ((Double) ((Tuple3) arrayList.get(i3)).f2).doubleValue();
                            int length = (((String[]) tuple5.f1).length - i3) - 1;
                            int i4 = (10000 - i3) - 1;
                            strArr2[i4] = (String) ((Tuple3) arrayList.get(length)).f0;
                            dArr[i4] = ((Double) ((Tuple3) arrayList.get(length)).f1).doubleValue();
                            dArr2[i4] = ((Double) ((Tuple3) arrayList.get(length)).f2).doubleValue();
                        }
                        collector.collect(Row.of(new Object[]{0L, tuple5.f0}));
                        collector.collect(Row.of(new Object[]{1L, JsonConverter.toJson(strArr2)}));
                        collector.collect(Row.of(new Object[]{2L, JsonConverter.toJson(dArr)}));
                        collector.collect(Row.of(new Object[]{3L, JsonConverter.toJson(dArr2)}));
                        collector.collect(Row.of(new Object[]{4L, JsonConverter.toJson(tuple5.f4)}));
                    }
                }
            }
        }), new TableSchema(new String[]{"id", "info"}, new TypeInformation[]{Types.LONG, Types.STRING})), DataSetConversionUtil.toTable(Long.valueOf(j), (DataSet<Row>) mapPartition, new TableSchema(new String[]{"col_name", "importance"}, new TypeInformation[]{Types.STRING, Types.DOUBLE})), DataSetConversionUtil.toTable(Long.valueOf(j), (DataSet<Row>) mapPartition2, new TableSchema(new String[]{"col_name", ConstraintVariable.weight}, new TypeInformation[]{Types.STRING, Types.DOUBLE}))};
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Orders the distinct label values of a binary classification problem. The label whose
     * {@code toString()} is lexicographically largest comes first.
     *
     * <p>Only the first two labels produced by {@code iterable} are compared and possibly swapped.
     * Any further labels follow them in iteration order.
     *
     * @param iterable the distinct label values (for example the contents of a {@code HashSet})
     * @return the labels as an array, with the lexicographically larger of the first two at index 0
     * @throws IllegalArgumentException if fewer than two labels are supplied, since two classes
     *     are needed to order them
     */
    public static Object[] orderLabels(Iterable<Object> iterable) {
        List<Object> labels = new ArrayList<>();
        for (Object label : iterable) {
            labels.add(label);
        }
        if (labels.size() < 2) {
            throw new IllegalArgumentException(
                "Binary classification requires at least two distinct label values, but found "
                    + labels.size() + ": " + labels);
        }
        Object[] array = labels.toArray(new Object[0]);
        // Put the label with the lexicographically larger string form first. Equal string forms
        // also swap, which leaves the ordering unchanged in every way that matters.
        if (array[1].toString().compareTo(array[0].toString()) >= 0) {
            Object tmp = array[0];
            array[0] = array[1];
            array[1] = tmp;
        }
        return array;
    }

    /**
     * Builds the optimizer for a linear model and runs it.
     *
     * <p>The coefficient dimension is taken from {@code dataSet} when a vector column is configured.
     * Otherwise it is the number of feature columns, plus one when {@code WITH_INTERCEPT} is set. For
     * {@code LinearModelType.AFT} one further slot is added in both cases.
     *
     * <p>The optimizer comes from {@code OptimizerFactory} when {@code OPTIM_METHOD} is present.
     * Otherwise OWL-QN is used when the L1 coefficient exceeds {@code Criteria.INVALID_GAIN}, and
     * L-BFGS in all other cases.
     *
     * @param params          training parameters
     * @param dataSet         coefficient dimension reported for vector input
     * @param dataSet2        preprocessed training samples
     * @param dataSet3        initial coefficients, passed to {@code Optimizer.initCoefWith}
     * @param linearModelType model type, used to pick the objective function and the AFT extra slot
     * @param mLEnvironment   environment providing the {@code ExecutionEnvironment}
     * @return the result of {@code Optimizer.optimize()}
     */
    public static DataSet<Tuple2<DenseVector, double[]>> optimize(Params params, DataSet<Integer> dataSet, DataSet<Tuple3<Double, Double, Vector>> dataSet2, DataSet<DenseVector> dataSet3, final LinearModelType linearModelType, MLEnvironment mLEnvironment) {
        final boolean withIntercept = ((Boolean) params.get(LinearTrainParams.WITH_INTERCEPT)).booleanValue();
        String[] featureCols = (String[]) params.get(LinearTrainParams.FEATURE_COLS);
        String vectorCol = (String) params.get(LinearTrainParams.VECTOR_COL);
        // Treat an empty vector-column name and an empty feature-column list as "not set".
        if ("".equals(vectorCol)) {
            vectorCol = null;
        }
        if (ArrayUtils.isEmpty(featureCols)) {
            featureCols = null;
        }
        // AFT models need one coefficient slot beyond the features (and intercept).
        final int aftExtraDim = linearModelType == LinearModelType.AFT ? 1 : 0;

        DataSet<Integer> coefDim;
        if (vectorCol != null) {
            coefDim = dataSet.map(new MapFunction<Integer, Integer>() {
                private static final long serialVersionUID = 5249103591725412746L;

                @Override
                public Integer map(Integer num) {
                    return Integer.valueOf(num.intValue() + aftExtraDim);
                }
            });
        } else {
            if (!$assertionsDisabled && featureCols == null) {
                throw new AssertionError();
            }
            ExecutionEnvironment executionEnvironment = mLEnvironment.getExecutionEnvironment();
            coefDim = executionEnvironment.fromElements(
                Integer.valueOf(featureCols.length + (withIntercept ? 1 : 0) + aftExtraDim));
        }

        DataSet<OptimObjFunc> objFunc = mLEnvironment.getExecutionEnvironment()
            .fromElements(new OptimObjFunc[]{OptimObjFunc.getObjFunction(linearModelType, params)});

        Optimizer optimizer;
        if (params.contains(LinearTrainParams.OPTIM_METHOD)) {
            optimizer = OptimizerFactory.create(objFunc, dataSet2, coefDim, params, (LinearTrainParams.OptimMethod) params.get(LinearTrainParams.OPTIM_METHOD));
        } else if (((Double) params.get(HasL1.L_1)).doubleValue() > Criteria.INVALID_GAIN) {
            // NOTE(review): Criteria.INVALID_GAIN looks like a decompiler substitution for 0.0; confirm.
            optimizer = new Owlqn(objFunc, dataSet2, coefDim, params);
        } else {
            optimizer = new Lbfgs(objFunc, dataSet2, coefDim, params);
        }
        optimizer.initCoefWith(dataSet3);
        return optimizer.optimize();
    }

    /**
     * Converts the rows of {@code batchOperator} into per-sample tuples by running {@code Transform}
     * over every partition.
     *
     * <p>When neither {@code FEATURE_COLS} nor {@code VECTOR_COL} is set, every numeric column other
     * than the label is used as a feature. That choice is written back to {@code params} as
     * {@code FEATURE_COLS}. Each feature column must pass {@code TableUtil.isSupportedNumericType}.
     *
     * @param batchOperator training data
     * @param params        training parameters; may be updated with the default feature columns
     * @param z             forwarded as the first argument of {@code Transform}
     * @param z2            forwarded as the last argument of {@code Transform}
     * @return the transformed samples
     */
    public static DataSet<Tuple3<Double, Object, Vector>> transform(BatchOperator<?> batchOperator, Params params, boolean z, boolean z2) {
        String[] featureCols = (String[]) params.get(LinearTrainParams.FEATURE_COLS);
        String labelCol = (String) params.get(LinearTrainParams.LABEL_COL);
        String weightCol = (String) params.get(LinearTrainParams.WEIGHT_COL);
        String vectorCol = (String) params.get(LinearTrainParams.VECTOR_COL);
        boolean withIntercept = ((Boolean) params.get(LinearTrainParams.WITH_INTERCEPT)).booleanValue();
        TableSchema schema = batchOperator.getSchema();
        String[] colNames = batchOperator.getColNames();

        // No explicit features and no vector column: default to all numeric columns except the label.
        if (null == featureCols && null == vectorCol) {
            featureCols = TableUtil.getNumericCols(schema, new String[]{labelCol});
            params.set(LinearTrainParams.FEATURE_COLS, featureCols);
        }
        int labelIdx = TableUtil.findColIndexWithAssertAndHint(schema.getFieldNames(), labelCol);
        int[] featureIndices = null;
        if (featureCols != null) {
            TypeInformation<?>[] fieldTypes = schema.getFieldTypes();
            featureIndices = new int[featureCols.length];
            for (int i = 0; i < featureCols.length; i++) {
                int colIdx = TableUtil.findColIndexWithAssertAndHint(colNames, featureCols[i]);
                featureIndices[i] = colIdx;
                TypeInformation<?> colType = fieldTypes[colIdx];
                AkPreconditions.checkState(TableUtil.isSupportedNumericType(colType), "linear algorithm only support numerical data type. Current type is : " + colType);
            }
        }
        // -1 marks an absent weight/vector column.
        int weightIdx = weightCol != null ? TableUtil.findColIndexWithAssertAndHint(colNames, weightCol) : -1;
        int vectorIdx = vectorCol != null ? TableUtil.findColIndexWithAssertAndHint(colNames, vectorCol) : -1;
        return batchOperator.getDataSet().mapPartition(new Transform(z, weightIdx, vectorIdx, featureIndices, labelIdx, withIntercept, z2));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Resolves a short type tag for every listed feature column of {@code batchOperator}.
     *
     * @param batchOperator operator whose schema contains the feature columns
     * @param strArr        feature column names; may be {@code null}
     * @return one tag per column ("double", "float", "long", "int", "short", "bool" or "decimal"),
     *     or {@code null} when {@code strArr} is {@code null}
     * @throws AkIllegalArgumentException if a column has any other type
     */
    public static String[] getFeatureTypes(BatchOperator<?> batchOperator, String[] strArr) {
        if (strArr == null) {
            return null;
        }
        final String[] tags = new String[strArr.length];
        int slot = 0;
        for (final String colName : strArr) {
            final int colIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), colName);
            tags[slot++] = featureTypeTag(batchOperator.getSchema().getFieldTypes()[colIdx]);
        }
        return tags;
    }

    /**
     * Maps a single column type to its feature type tag.
     *
     * @throws AkIllegalArgumentException if {@code type} is not one of the supported numeric types
     */
    private static String featureTypeTag(TypeInformation<?> type) {
        if (type.equals(Types.DOUBLE)) {
            return "double";
        }
        if (type.equals(Types.FLOAT)) {
            return "float";
        }
        if (type.equals(Types.LONG)) {
            return "long";
        }
        if (type.equals(Types.INT)) {
            return "int";
        }
        if (type.equals(Types.SHORT)) {
            return "short";
        }
        if (type.equals(Types.BOOLEAN)) {
            return "bool";
        }
        if (type.equals(Types.BIG_DEC)) {
            return "decimal";
        }
        throw new AkIllegalArgumentException("Linear algorithm only support numerical data type. Current type is : " + type);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Converts transformed samples into the {@code (f0, numeric label, feature vector)} triples fed
     * to the optimizer.
     *
     * <p>Only rows whose {@code f0} is strictly greater than {@code Criteria.INVALID_GAIN} are kept.
     * The label ({@code f1}) is handled as follows:
     * <ul>
     *   <li>when {@code z} is {@code true}, it is parsed as a double;</li>
     *   <li>otherwise it becomes {@code 1.0} if it equals the first broadcast label value, and
     *       {@code -1.0} otherwise.</li>
     * </ul>
     *
     * <p>Dense vectors shorter than the broadcast feature size are copied into a zero-padded vector.
     * When {@code STANDARDIZATION} is set:
     * <ul>
     *   <li>dense features are divided by {@code meanVar[1]}, and first centred by
     *       {@code meanVar[0]} when {@code WITH_INTERCEPT} is also set;</li>
     *   <li>sparse vectors are only divided, never centred.</li>
     * </ul>
     * Sparse vectors whose size is -1 or 0 are resized to the feature size. Feature values are
     * written back into the incoming vector objects (or into the padded copy).
     *
     * @param dataSet  transformed samples
     * @param params   training parameters; reads {@code STANDARDIZATION} and {@code WITH_INTERCEPT}
     * @param z        {@code true} to parse labels as numbers (presumably the regression case; confirm)
     * @param dataSet2 mean/scale vectors, broadcast as {@code MEAN_VAR}
     * @param dataSet3 label values, broadcast as {@code "labelValues"}; element 0 maps to {@code 1.0}
     * @param dataSet4 feature size, broadcast as {@code "featureSize"}
     * @return the optimizer-ready samples
     */
    public static DataSet<Tuple3<Double, Double, Vector>> preProcess(DataSet<Tuple3<Double, Object, Vector>> dataSet, Params params, final boolean z, DataSet<DenseVector[]> dataSet2, DataSet<Object[]> dataSet3, DataSet<Integer> dataSet4) {
        final boolean booleanValue = ((Boolean) params.get(LinearTrainParams.STANDARDIZATION)).booleanValue();
        final boolean booleanValue2 = ((Boolean) params.get(LinearTrainParams.WITH_INTERCEPT)).booleanValue();
        return dataSet.mapPartition(new RichMapPartitionFunction<Tuple3<Double, Object, Vector>, Tuple3<Double, Double, Vector>>() { // from class: com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp.15
            private static final long serialVersionUID = -3931917328901089041L;
            private DenseVector[] meanVar;
            private Object[] labelValues = null;
            private int featureSize;

            public void open(Configuration configuration) {
                this.meanVar = (DenseVector[]) getRuntimeContext().getBroadcastVariable(BaseLinearModelTrainBatchOp.MEAN_VAR).get(0);
                this.labelValues = (Object[]) getRuntimeContext().getBroadcastVariable("labelValues").get(0);
                this.featureSize = ((Integer) getRuntimeContext().getBroadcastVariable("featureSize").get(0)).intValue();
                // NOTE(review): meanVar is the broadcast object itself. Flink shares broadcast variables among the
                // parallel subtasks of a TaskManager, so if modifyMeanVar mutates it in place, the adjustment may be
                // applied more than once. Confirm modifyMeanVar is idempotent, or copy meanVar first.
                BaseLinearModelTrainBatchOp.modifyMeanVar(booleanValue, this.meanVar);
            }

            public void mapPartition(Iterable<Tuple3<Double, Object, Vector>> iterable, Collector<Tuple3<Double, Double, Vector>> collector) {
                for (Tuple3<Double, Object, Vector> tuple3 : iterable) {
                    Vector vector = (Vector) tuple3.f2;
                    // Keep only rows with f0 strictly above INVALID_GAIN.
                    // NOTE(review): getSideTablesOfCoefficient counts rows with f0 >= INVALID_GAIN instead; confirm the difference is intended.
                    if (((Double) tuple3.f0).doubleValue() > Criteria.INVALID_GAIN) {
                        // Numeric label when z, otherwise +1/-1 against the first label value.
                        Double valueOf = Double.valueOf(z ? Double.parseDouble(tuple3.f1.toString()) : tuple3.f1.equals(this.labelValues[0]) ? 1.0d : -1.0d);
                        if (vector instanceof DenseVector) {
                            // Zero-pad short dense vectors up to the full feature size.
                            if (vector.size() < this.featureSize) {
                                DenseVector denseVector = new DenseVector(this.featureSize);
                                for (int i = 0; i < vector.size(); i++) {
                                    denseVector.set(i, vector.get(i));
                                }
                                vector = denseVector;
                            }
                            if (booleanValue) {
                                if (booleanValue2) {
                                    // Centre and scale: (x - meanVar[0]) / meanVar[1].
                                    for (int i2 = 0; i2 < vector.size(); i2++) {
                                        vector.set(i2, (vector.get(i2) - this.meanVar[0].get(i2)) / this.meanVar[1].get(i2));
                                    }
                                } else {
                                    // Scale only: x / meanVar[1].
                                    for (int i3 = 0; i3 < vector.size(); i3++) {
                                        vector.set(i3, vector.get(i3) / this.meanVar[1].get(i3));
                                    }
                                }
                            }
                        } else {
                            // Sparse input: scale stored values only, so zeros stay zero (no centring).
                            if (booleanValue) {
                                int[] indices = ((SparseVector) vector).getIndices();
                                double[] values = ((SparseVector) vector).getValues();
                                for (int i4 = 0; i4 < indices.length; i4++) {
                                    values[i4] = values[i4] / this.meanVar[1].get(indices[i4]);
                                }
                            }
                            // An unset or zero size is replaced by the broadcast feature size.
                            if (vector.size() == -1 || vector.size() == 0) {
                                ((SparseVector) vector).setSize(this.featureSize);
                            }
                        }
                        collector.collect(Tuple3.of(tuple3.f0, valueOf, vector));
                    }
                }
            }
        }).withBroadcastSet(dataSet2, MEAN_VAR).withBroadcastSet(dataSet3, "labelValues").withBroadcastSet(dataSet4, "featureSize");
    }

    /**
     * Reports whether the given model is trained as a regression problem. As a side effect, it
     * moves the model-specific regularization parameter into the generic L1/L2 parameter.
     *
     * <ul>
     *   <li>{@code LinearReg}: always returns {@code true}.
     *       "Ridge Regression" moves {@code lambda} to L2 (lambda must be strictly positive).
     *       "LASSO" moves {@code lambda} to L1 (lambda must be non-negative).
     *       Any other name leaves {@code params} untouched.</li>
     *   <li>{@code SVR}: validates {@code tau >= 0} and {@code C > 0}, sets {@code L2 = 1 / C},
     *       removes {@code C} and returns {@code true}.</li>
     *   <li>Any other model type: returns {@code false} and leaves {@code params} untouched.</li>
     * </ul>
     *
     * @param params          training parameters; mutated in place as described above
     * @param linearModelType the linear model family being trained
     * @param str             the human-readable model name (e.g. "Ridge Regression", "LASSO")
     * @return {@code true} for regression models, {@code false} otherwise
     * @throws AkIllegalArgumentException if LASSO lambda, SVR tau or SVR C is out of range
     */
    private static boolean getIsRegProc(Params params, LinearModelType linearModelType, String str) {
        if (linearModelType.equals(LinearModelType.LinearReg)) {
            if ("Ridge Regression".equals(str)) {
                double lambda = params.get(RidgeRegTrainParams.LAMBDA);
                // The check rejects zero, so the message must not advertise zero as valid.
                AkPreconditions.checkState(lambda > 0.0, "Lambda must be positive number! lambda is : " + lambda);
                params.set(HasL2.L_2, lambda);
                params.remove(RidgeRegTrainParams.LAMBDA);
            } else if ("LASSO".equals(str)) {
                double lambda = params.get(LassoRegTrainParams.LAMBDA);
                if (lambda < 0.0) {
                    throw new AkIllegalArgumentException("Lambda must be positive number or zero!");
                }
                params.set(HasL1.L_1, lambda);
                // Remove the LASSO parameter that was actually read (previously the Ridge one).
                params.remove(LassoRegTrainParams.LAMBDA);
            }
            return true;
        }
        if (linearModelType.equals(LinearModelType.SVR)) {
            double tau = params.get(LinearSvrTrainParams.TAU);
            double c = params.get(LinearSvrTrainParams.C);
            if (tau < 0.0) {
                throw new AkIllegalArgumentException("Parameter tau must be positive number or zero!");
            }
            if (c <= 0.0) {
                throw new AkIllegalArgumentException("Parameter C must be positive number!");
            }
            // SVR's C is inversely related to the L2 penalty used by the shared optimizer.
            params.set(HasL2.L_2, 1.0d / c);
            params.remove(LinearSvrTrainParams.C);
            return true;
        }
        return false;
    }

    /**
     * Builds the final {@link LinearModelData} from the optimized coefficient vector.
     *
     * <p>For non-AFT models, the standardization statistics are first sanitized via
     * {@link #modifyMeanVar}. When {@code z2} is set, the coefficients in {@code tuple2.f0}
     * are mapped back from the standardized space in place:
     * <ul>
     *   <li>each coefficient is divided by {@code denseVectorArr[1]};</li>
     *   <li>when {@code z} is set, slot 0 is treated as the intercept and is reduced by
     *       {@code sum(coef_i * denseVectorArr[0][i] / denseVectorArr[1][i])} for {@code i >= 1}.</li>
     * </ul>
     * {@code ModelParamName.VECTOR_SIZE} is written into {@code params}. Its value is the
     * coefficient count minus one slot when HAS_INTERCEPT_ITEM is set, and minus one more
     * slot for AFT models.
     *
     * @param params          training params; VECTOR_SIZE is set on it
     * @param strArr          passed unchanged to the {@link LinearModelData} constructor
     * @param typeInformation passed unchanged to the {@link LinearModelData} constructor
     * @param denseVectorArr  {@code [0]} offsets, {@code [1]} divisors (the file's "meanVar")
     * @param z               whether slot 0 of the coefficients is an intercept
     * @param z2              whether the coefficients must be un-standardized
     * @param tuple2          {@code f0} holds the coefficients and is modified in place; {@code f1} is unused here
     * @return the assembled model data, with label name and feature types copied from {@code params}
     */
    public static LinearModelData buildLinearModelData(Params params, String[] strArr, TypeInformation<?> typeInformation, DenseVector[] denseVectorArr, boolean z, boolean z2, Tuple2<DenseVector, double[]> tuple2) {
        final boolean isAft = LinearModelType.AFT.equals(params.get(ModelParamName.LINEAR_MODEL_TYPE));
        if (!isAft) {
            modifyMeanVar(z2, denseVectorArr);
        }

        DenseVector coefficients = tuple2.f0;
        int interceptSlots = ((Boolean) params.get(ModelParamName.HAS_INTERCEPT_ITEM)) ? 1 : 0;
        int aftSlots = isAft ? 1 : 0;
        params.set(ModelParamName.VECTOR_SIZE, Integer.valueOf(coefficients.size() - interceptSlots - aftSlots));

        if (!isAft && z2) {
            DenseVector offsets = denseVectorArr[0];
            DenseVector scales = denseVectorArr[1];
            int dim = offsets.size();
            if (z) {
                // Fold the feature offsets into the intercept while rescaling the weights.
                double interceptShift = 0.0d;
                for (int j = 1; j < dim; j++) {
                    double weight = coefficients.get(j);
                    interceptShift += weight * offsets.get(j) / scales.get(j);
                    coefficients.set(j, weight / scales.get(j));
                }
                coefficients.set(0, coefficients.get(0) - interceptShift);
            } else {
                for (int j = 0; j < dim; j++) {
                    coefficients.set(j, coefficients.get(j) / scales.get(j));
                }
            }
        }

        LinearModelData modelData = new LinearModelData(typeInformation, params, strArr, coefficients);
        modelData.labelName = (String) params.get(LinearTrainParams.LABEL_COL);
        modelData.featureTypes = (String[]) params.get(ModelParamName.FEATURE_TYPES);
        return modelData;
    }

    /**
     * Sanitizes the standardization statistics in place so that dividing by them is safe.
     *
     * <p>When {@code z} is true, every slot whose divisor {@code denseVectorArr[1]} equals
     * {@code Criteria.INVALID_GAIN} is changed: its divisor becomes 1.0 and its offset
     * {@code denseVectorArr[0]} becomes {@code Criteria.INVALID_GAIN}. In this file that
     * constant is used as 0.0, so these are zero-spread features. When {@code z} is false,
     * nothing is modified.
     *
     * <p>NOTE(review): the decompiler reports this method was originally private.
     * It is left public because other code may rely on it.
     *
     * @param z              enables the sanitizing; when false, the method is a no-op
     * @param denseVectorArr {@code [0]} offsets, {@code [1]} divisors; both modified in place
     */
    public static void modifyMeanVar(boolean z, DenseVector[] denseVectorArr) {
        if (!z) {
            return;
        }
        DenseVector offsets = denseVectorArr[0];
        DenseVector scales = denseVectorArr[1];
        for (int idx = 0; idx < scales.size(); idx++) {
            if (scales.get(idx) == Criteria.INVALID_GAIN) {
                scales.set(idx, 1.0d);
                offsets.set(idx, Criteria.INVALID_GAIN);
            }
        }
    }

    /**
     * Wraps the given rows in a {@link LinearModelTrainInfo}.
     *
     * @param list the train-info rows to wrap
     * @return a new {@link LinearModelTrainInfo} built from {@code list}
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public LinearModelTrainInfo createTrainInfo(List<Row> list) {
        LinearModelTrainInfo trainInfo = new LinearModelTrainInfo(list);
        return trainInfo;
    }

    /**
     * Returns the side output that carries the training info (side output index 0).
     *
     * @return side output 0 of this operator
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public BatchOperator<?> getSideOutputTrainInfo() {
        final int trainInfoOutputIndex = 0;
        return getSideOutput(trainInfoOutputIndex);
    }

    // No hand-written bridge methods here: javac synthesizes the erased
    // linkFrom(BatchOperator[]) and createTrainInfo(List) bridges itself. Declaring them
    // explicitly clashes with the generic methods that have the same erasure and fails to compile.

    static {
        $assertionsDisabled = !BaseLinearModelTrainBatchOp.class.desiredAssertionStatus();
        compare = (tuple3, tuple32) -> {
            return ((Double) tuple32.f2).compareTo((Double) tuple3.f2);
        };
    }
}
