package com.alibaba.alink.operator.local.classification;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.lazy.WithTrainInfoLocalOp;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelTrainInfo;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.nlp.WordCountUtil;
import com.alibaba.alink.operator.common.optim.LocalOptimizer;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintVariable;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.classification.BaseLinearModelTrainLocalOp;
import com.alibaba.alink.params.regression.LassoRegTrainParams;
import com.alibaba.alink.params.regression.LinearSvrTrainParams;
import com.alibaba.alink.params.regression.RidgeRegTrainParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import com.alibaba.alink.params.shared.linear.HasL1;
import com.alibaba.alink.params.shared.linear.HasL2;
import com.alibaba.alink.params.shared.linear.HasWithIntercept;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(value = PortType.MODEL, isOptional = true)})
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_INFO), @PortSpec(value = PortType.DATA, desc = PortDesc.FEATURE_IMPORTANCE), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_WEIGHT)})
@FeatureColsVectorColMutexRule
@Internal
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
/* loaded from: input_file:com/alibaba/alink/operator/local/classification/BaseLinearModelTrainLocalOp.class */
public abstract class BaseLinearModelTrainLocalOp<T extends BaseLinearModelTrainLocalOp<T>> extends LocalOperator<T> implements WithTrainInfoLocalOp<LinearModelTrainInfo, T> {
    // Label-cardinality thresholds. NOTE(review): their values equal the literals 1000 / 0.5 used
    // in the label sanity check of preprocess — presumably javac inlined these constants; confirm.
    static final int MAX_LABELS = 1000;
    static final double LABEL_RATIO = 0.5d;

    /** Model name written into the model metadata (MODEL_NAME). */
    private final String modelName;

    /** Kind of linear model this operator trains. */
    private final LinearModelType linearModelType;

    /**
     * Orders (name, weight, importance) triples by importance ({@code f2}), largest first.
     *
     * <p>Declared {@code final} so that no caller can swap the shared ordering at runtime.
     */
    public static final Comparator<Tuple3<String, Double, Double>> compare =
        (left, right) -> Double.compare(right.f2, left.f2);

    /**
     * Creates a local linear-model training operator.
     *
     * @param params          operator parameters, handed to the {@code LocalOperator} base class
     * @param linearModelType kind of linear model to train
     * @param modelName       name stored as MODEL_NAME in the trained model's metadata
     */
    public BaseLinearModelTrainLocalOp(Params params, LinearModelType linearModelType, String modelName) {
        super(params);
        this.linearModelType = linearModelType;
        this.modelName = modelName;
    }

    /**
     * Trains the linear model on the first input, optionally warm-starting from an initial model.
     *
     * <p>{@code localOperatorArr[0]} is the training data. An optional second element is an initial
     * model whose coefficients seed the optimizer (see {@link #initializeModelCoefs}).
     *
     * <p>The main output is the serialized model. The side outputs are the tables built by
     * {@link #getSideTablesOfCoefficient}: model info, feature importance and feature weights.
     *
     * <p>When neither featureCols nor vectorCol is set, {@code preprocess} writes the inferred numeric
     * feature columns back into this operator's params.
     *
     * @param localOperatorArr the training data, optionally followed by an initial model
     * @return this operator
     * @throws AkIllegalArgumentException if featureCols and vectorCol are both set
     */
    @Override
    @SuppressWarnings("unchecked")
    public T linkFrom(LocalOperator<?>... localOperatorArr) {
        LocalOperator<?> trainData;
        LocalOperator<?> initModel = null;
        if (localOperatorArr.length == 1) {
            trainData = checkAndGetFirst(localOperatorArr);
        } else {
            trainData = localOperatorArr[0];
            initModel = localOperatorArr[1];
        }
        Params params = getParams();
        if (params.contains(HasFeatureCols.FEATURE_COLS) && params.contains(HasVectorCol.VECTOR_COL)) {
            throw new AkIllegalArgumentException("FeatureCols and vectorCol cannot be set at the same time.");
        }
        try {
            MTable dataTable = trainData.getOutputTable();
            boolean isRegProc = getIsRegProc(params, this.linearModelType, this.modelName);
            boolean withIntercept = params.get(LinearTrainParams.WITH_INTERCEPT);
            // Regression labels are always parsed as doubles; classifiers keep the column's type.
            TypeInformation<?> labelType = isRegProc
                ? Types.DOUBLE
                : dataTable.getColTypes()[TableUtil.findColIndexWithAssertAndHint(
                    dataTable.getColNames(), params.get(LinearTrainParams.LABEL_COL))];

            Tuple4<List<Tuple3<Double, Double, Vector>>, DenseVector[], Integer, Object[]> prepared =
                preprocess(dataTable, params, isRegProc, this.linearModelType);
            List<Tuple3<Double, Double, Vector>> samples = prepared.f0;
            DenseVector[] meanVar = prepared.f1;
            Integer featureSize = prepared.f2;
            Object[] labelValues = prepared.f3;

            DenseVector initCoefs;
            if (initModel == null) {
                // Classifiers keep one coefficient block per label value except one.
                initCoefs = DenseVector.zeros(featureSize * (isRegProc ? 1 : labelValues.length - 1));
            } else {
                initCoefs = initializeModelCoefs(
                    initModel.getOutputTable().getRows(), featureSize, meanVar, params, this.linearModelType);
            }
            if (LinearModelType.Softmax == this.linearModelType) {
                // The Softmax objective function reads the class count from the training params.
                params.set(ModelParamName.NUM_CLASSES, labelValues.length);
            }
            Tuple2<DenseVector, double[]> optimized = LocalOptimizer.optimize(
                OptimObjFunc.getObjFunction(this.linearModelType, params), samples, initCoefs, params);

            String[] featureCols = params.get(LinearTrainParams.FEATURE_COLS);
            Params meta = new Params();
            meta.set(ModelParamName.MODEL_NAME, this.modelName);
            meta.set(ModelParamName.LINEAR_MODEL_TYPE, this.linearModelType);
            meta.set(ModelParamName.HAS_INTERCEPT_ITEM, withIntercept);
            meta.set(ModelParamName.VECTOR_COL_NAME, params.get(LinearTrainParams.VECTOR_COL));
            meta.set(LinearTrainParams.LABEL_COL, params.get(LinearTrainParams.LABEL_COL));
            meta.set(ModelParamName.FEATURE_TYPES, getFeatureTypes(dataTable.getSchema(), featureCols));
            if (LinearModelType.LinearReg != this.linearModelType
                && LinearModelType.SVR != this.linearModelType
                && LinearModelType.AFT != this.linearModelType) {
                meta.set(ModelParamName.LABEL_VALUES, labelValues);
            }
            if (LinearModelType.Softmax == this.linearModelType) {
                meta.set(ModelParamName.NUM_CLASSES, labelValues.length);
            }
            // AFT models are built without the standardization statistics.
            DenseVector[] modelMeanVar = LinearModelType.AFT == this.linearModelType ? null : meanVar;
            boolean standardization = params.get(LinearTrainParams.STANDARDIZATION);
            LinearModelData modelData = buildLinearModelData(
                meta, featureCols, labelType, modelMeanVar, withIntercept, standardization, optimized);

            LinearModelDataConverter converter = new LinearModelDataConverter(labelType);
            RowCollector collector = new RowCollector();
            converter.save(modelData, collector);
            List<Row> modelRows = collector.getRows();
            setOutputTable(new MTable(modelRows, converter.getModelSchema()));
            setSideOutputTables(getSideTablesOfCoefficient(
                modelRows, optimized.f1, samples, featureSize, featureCols, withIntercept));
            return (T) this;
        } catch (ExceptionWithErrorCode e) {
            throw e;
        } catch (Exception e) {
            // The cause travels with the wrapper; printing the trace here as well only adds noise.
            throw new AkUnclassifiedErrorException("Error. ", e);
        }
    }

    /**
     * Loads an initial (warm-start) linear model and maps its coefficients into the feature space
     * used for training.
     *
     * <p>When {@code preprocess} standardized the training data, every sample was transformed as
     * {@code (x - mean) / scale}. The coefficients are rescaled to match:
     * <ul>
     *   <li>every non-intercept coefficient is multiplied by its scale;</li>
     *   <li>when the model has an intercept, the intercept absorbs {@code sum(coef_i * mean_i)};</li>
     *   <li>for Softmax models, this is done separately for each class block of
     *       {@code meanVar[0].size()} coefficients.</li>
     * </ul>
     *
     * @param list            rows of the serialized initial model
     * @param num             feature dimension of the training data (including the intercept column, if any)
     * @param denseVectorArr  {@code [mean, scale]} computed by {@code preprocess}; its entries are left
     *                        {@code null} when standardization is disabled, in which case the
     *                        coefficients are returned unchanged
     * @param params          training params; WITH_INTERCEPT must match the initial model
     * @param linearModelType model type being trained; must match the initial model
     * @return the initial model's coefficient vector, rescaled in place
     * @throws AkIllegalArgumentException if the intercept setting or the model type differ
     * @throws AkIllegalDataException     if the model's vector size differs from {@code num}
     */
    public static DenseVector initializeModelCoefs(List<Row> list, Integer num, DenseVector[] denseVectorArr, Params params, LinearModelType linearModelType) {
        LinearModelData load = new LinearModelDataConverter().load(list);
        boolean withIntercept = params.get(HasWithIntercept.WITH_INTERCEPT);
        if (load.hasInterceptItem != withIntercept) {
            throw new AkIllegalArgumentException("Initial linear model is not compatible with parameter setting.InterceptItem parameter setting error.");
        }
        if (load.linearModelType != linearModelType) {
            throw new AkIllegalArgumentException("Initial linear model is not compatible with parameter setting.linearModelType setting error.");
        }
        if (load.vectorSize != num.intValue()) {
            throw new AkIllegalDataException("Initial linear model is not compatible with training data.  vector size not equal, vector size in init model is : " + load.vectorSize + " and vector size of train data is : " + num);
        }
        DenseVector coefs = load.coefVector;
        // Without standardization the training data is not transformed, so the initial
        // coefficients already live in the training feature space.
        if (denseVectorArr == null || denseVectorArr[0] == null || denseVectorArr[1] == null) {
            return coefs;
        }
        DenseVector mean = denseVectorArr[0];
        DenseVector scale = denseVectorArr[1];
        int size = mean.size();
        if (LinearModelType.Softmax.equals(linearModelType)) {
            if (load.hasInterceptItem) {
                // One block of `size` coefficients per class block, each with its own intercept at
                // offset 0. The block count follows from the vector length instead of a fixed value.
                int numBlocks = size == 0 ? 0 : coefs.size() / size;
                for (int block = 0; block < numBlocks; block++) {
                    int base = block * size;
                    double interceptShift = 0.0d;
                    for (int j = 1; j < size; j++) {
                        int pos = base + j;
                        interceptShift += coefs.get(pos) * mean.get(j);
                        coefs.set(pos, coefs.get(pos) * scale.get(j));
                    }
                    coefs.set(base, coefs.get(base) + interceptShift);
                }
            } else {
                for (int i = 0; i < coefs.size(); i++) {
                    coefs.set(i, coefs.get(i) * scale.get(i % scale.size()));
                }
            }
        } else if (load.hasInterceptItem) {
            double interceptShift = 0.0d;
            for (int j = 1; j < size; j++) {
                interceptShift += coefs.get(j) * mean.get(j);
                coefs.set(j, coefs.get(j) * scale.get(j));
            }
            coefs.set(0, coefs.get(0) + interceptShift);
        } else {
            for (int j = 0; j < size; j++) {
                coefs.set(j, coefs.get(j) * scale.get(j));
            }
        }
        return coefs;
    }

    /**
     * Builds the three side tables of a trained linear model: model info, feature importance and
     * feature weights.
     *
     * <p>The importance of feature j is {@code |coef_j * std_j|}.
     * <ul>
     *   <li>{@code std_j} is the sample standard deviation of feature j over the training samples
     *       whose weight ({@code f0}) is {@code >= Criteria.INVALID_GAIN}.</li>
     *   <li>It is taken as 0 when fewer than two such samples exist.</li>
     *   <li>The training samples are only read, never modified.</li>
     * </ul>
     *
     * <p>The info table has rows 0..4:
     * <ol start="0">
     *   <li>model meta info;</li>
     *   <li>feature names;</li>
     *   <li>weights;</li>
     *   <li>importances;</li>
     *   <li>{@code dArr}.</li>
     * </ol>
     * With 10000 or more feature names, rows 1-3 keep only the 5000 most and 5000 least important
     * features, ranked with {@link #compare}.
     *
     * <p>The weight table gets an extra {@code "_intercept_"} row first whenever the coefficient
     * count differs from the number of feature names.
     *
     * @param list  rows of the serialized model
     * @param dArr  array returned by the optimizer (presumably the loss history — confirm); stored
     *              verbatim as info row 4
     * @param list2 training samples as (weight, label, features)
     * @param num   feature dimension, including the intercept column when {@code z} is true
     * @param strArr feature column names, or {@code null} to name features by their index
     * @param z     whether the model has an intercept item
     * @return {@code [info(id, info), importance(col_name, importance), weight(col_name, ConstraintVariable.weight)]}
     */
    public static MTable[] getSideTablesOfCoefficient(List<Row> list, double[] dArr, List<Tuple3<Double, Double, Vector>> list2, Integer num, String[] strArr, boolean z) {
        LinearModelData modelData = new LinearModelDataConverter().load(list);
        int offset = z ? 1 : 0;
        int featureNum = num - offset;

        // Per-feature sum and sum of squares over the samples that pass the weight filter.
        double[] sums = new double[featureNum];
        double[] squareSums = new double[featureNum];
        int count = 0;
        for (Tuple3<Double, Double, Vector> sample : list2) {
            if (!(sample.f0 >= Criteria.INVALID_GAIN)) {
                continue;
            }
            Vector features = sample.f2;
            if (strArr == null && features instanceof SparseVector) {
                // Only stored entries contribute; implicit zeros add nothing to either sum. The
                // vector is read as-is and never resized, so the caller's samples stay consistent.
                SparseVector sparse = (SparseVector) features;
                int[] indices = sparse.getIndices();
                double[] values = sparse.getValues();
                for (int k = 0; k < values.length; k++) {
                    int pos = indices[k] - offset;
                    if (pos >= 0) { // with an intercept, index 0 is the constant column
                        sums[pos] += values[k];
                        squareSums[pos] += values[k] * values[k];
                    }
                }
            } else {
                for (int j = 0; j < featureNum; j++) {
                    double value = features.get(j + offset);
                    sums[j] += value;
                    squareSums[j] += value * value;
                }
            }
            count++;
        }

        DenseVector coefs = modelData.coefVector;
        String[] colNames = strArr;
        if (colNames == null) {
            colNames = new String[coefs.size() - offset];
            for (int j = 0; j < colNames.length; j++) {
                colNames[j] = String.valueOf(j);
            }
        }
        double[] importance = new double[coefs.size() - offset];
        for (int j = 0; j < featureNum; j++) {
            // The sample standard deviation needs at least two samples; with fewer it is taken as 0
            // (previously zero samples produced NaN).
            double std = 0.0d;
            if (count > 1) {
                double mean = sums[j] / count;
                std = Math.sqrt(Math.max(Criteria.INVALID_GAIN, squareSums[j] - count * mean * mean) / (count - 1));
            }
            importance[j] = Math.abs(coefs.get(j + offset) * std);
        }
        double[] coefData = coefs.getData();
        String metaJson = JsonConverter.toJson(modelData.getMetaInfo());

        List<Row> importanceRows = new ArrayList<>();
        for (int j = 0; j < colNames.length; j++) {
            importanceRows.add(Row.of(colNames[j], importance[j]));
        }

        List<Row> weightRows = new ArrayList<>();
        if (coefData.length == colNames.length) {
            for (int j = 0; j < colNames.length; j++) {
                weightRows.add(Row.of(colNames[j], coefData[j]));
            }
        } else {
            weightRows.add(Row.of("_intercept_", coefData[0]));
            for (int j = 0; j < colNames.length; j++) {
                weightRows.add(Row.of(colNames[j], coefData[j + 1]));
            }
        }

        List<Row> infoRows = new ArrayList<>();
        infoRows.add(Row.of(0L, metaJson));
        if (colNames.length < 10000) {
            infoRows.add(Row.of(1L, JsonConverter.toJson(colNames)));
            infoRows.add(Row.of(2L, JsonConverter.toJson(coefData)));
            infoRows.add(Row.of(3L, JsonConverter.toJson(importance)));
        } else {
            // Too many features to list: keep the head and tail of the importance ranking.
            List<Tuple3<String, Double, Double>> ranked = new ArrayList<>(colNames.length);
            for (int j = 0; j < colNames.length; j++) {
                ranked.add(Tuple3.of(colNames[j], coefData[j + offset], importance[j]));
            }
            ranked.sort(compare);
            String[] keptNames = new String[WordCountUtil.BOUND_SIZE];
            double[] keptWeights = new double[WordCountUtil.BOUND_SIZE];
            double[] keptImportance = new double[WordCountUtil.BOUND_SIZE];
            for (int j = 0; j < 5000; j++) {
                Tuple3<String, Double, Double> top = ranked.get(j);
                keptNames[j] = top.f0;
                keptWeights[j] = top.f1;
                keptImportance[j] = top.f2;
                Tuple3<String, Double, Double> bottom = ranked.get(colNames.length - j - 1);
                int slot = WordCountUtil.BOUND_SIZE - j - 1;
                keptNames[slot] = bottom.f0;
                keptWeights[slot] = bottom.f1;
                keptImportance[slot] = bottom.f2;
            }
            infoRows.add(Row.of(1L, JsonConverter.toJson(keptNames)));
            infoRows.add(Row.of(2L, JsonConverter.toJson(keptWeights)));
            infoRows.add(Row.of(3L, JsonConverter.toJson(keptImportance)));
        }
        infoRows.add(Row.of(4L, JsonConverter.toJson(dArr)));

        return new MTable[] {
            new MTable(infoRows, new TableSchema(new String[] {"id", "info"},
                new TypeInformation<?>[] {Types.LONG, Types.STRING})),
            new MTable(importanceRows, new TableSchema(new String[] {"col_name", "importance"},
                new TypeInformation<?>[] {Types.STRING, Types.DOUBLE})),
            new MTable(weightRows, new TableSchema(new String[] {"col_name", ConstraintVariable.weight},
                new TypeInformation<?>[] {Types.STRING, Types.DOUBLE}))
        };
    }

    /**
     * Collects the distinct label values into an array.
     *
     * <p>When there are at least two labels, the first two are arranged so that the one whose
     * {@code toString()} is lexicographically larger comes first. {@code preprocess} maps that
     * first label to 1.0 for binary classifiers. Any remaining labels keep iteration order.
     *
     * <p>Fewer than two labels are returned unchanged, so callers can report a descriptive error
     * instead of failing with an index error.
     *
     * @param iterable distinct label values
     * @return the labels as an array, first two ordered as described
     */
    private static Object[] orderLabels(Iterable<Object> iterable) {
        List<Object> collected = new ArrayList<>();
        for (Object label : iterable) {
            collected.add(label);
        }
        Object[] array = collected.toArray(new Object[0]);
        if (array.length < 2) {
            return array;
        }
        String first = array[0].toString();
        String second = array[1].toString();
        String larger = second.compareTo(first) > 0 ? second : first;
        if (second.equals(larger)) {
            Object tmp = array[0];
            array[0] = array[1];
            array[1] = tmp;
        }
        return array;
    }

    /**
     * Extracts training samples from {@code mTable}, optionally standardizes them, and encodes
     * the labels.
     *
     * <p><b>Features.</b> Features come either from numeric feature columns or from a vector column.
     * <ul>
     *   <li>When neither featureCols nor vectorCol is set, the feature columns default to every
     *       numeric column other than the label. They are then written back to FEATURE_COLS in
     *       {@code params}.</li>
     *   <li>Rows with a null label, a null feature value or a null vector are skipped.</li>
     *   <li>With WITH_INTERCEPT, a constant 1.0 is prepended to every feature vector.</li>
     *   <li>Vector input stays sparse only when sparse vectors outnumber dense ones more than two to
     *       one. Otherwise everything is densified.</li>
     * </ul>
     *
     * <p><b>Standardization</b> (STANDARDIZATION param):
     * <ul>
     *   <li>Sparse data is divided by the per-feature maximum absolute value (1 when that is not
     *       above {@code Criteria.INVALID_GAIN}).</li>
     *   <li>Dense data is centered and divided by the sample standard deviation.</li>
     *   <li>When that standard deviation equals {@code Criteria.INVALID_GAIN} (presumably 0, i.e. a
     *       constant feature — confirm), the feature is divided by its mean (or by 1 when the mean
     *       also equals it) and is not centered.</li>
     * </ul>
     *
     * <p><b>Labels.</b>
     * <ul>
     *   <li>For regression ({@code z}), the label is parsed as a double.</li>
     *   <li>Softmax maps the i-th distinct label to i.</li>
     *   <li>Other classifiers require exactly two labels, mapped to 1.0 and -1.0.</li>
     * </ul>
     *
     * @param mTable          training data
     * @param params          training params; FEATURE_COLS may be written back
     * @param z               whether this is a regression (labels are not encoded)
     * @param linearModelType model type, used for label encoding
     * @return a tuple of:
     *         (samples as (weight, encoded label, features),
     *          [mean, scale] — both null when standardization is off,
     *          feature dimension,
     *          distinct label values)
     * @throws AkIllegalDataException if a binary classifier does not see exactly two labels, or if the
     *                                label column has more than 1000 distinct values and more distinct
     *                                values than half the rows
     */
    private static Tuple4<List<Tuple3<Double, Double, Vector>>, DenseVector[], Integer, Object[]> preprocess(MTable mTable, Params params, boolean z, LinearModelType linearModelType) {
        boolean standardization = params.get(LinearTrainParams.STANDARDIZATION);
        boolean withIntercept = params.get(LinearTrainParams.WITH_INTERCEPT);
        String[] featureCols = params.get(LinearTrainParams.FEATURE_COLS);
        String labelCol = params.get(LinearTrainParams.LABEL_COL);
        String weightCol = params.get(LinearTrainParams.WEIGHT_COL);
        String vectorCol = params.get(LinearTrainParams.VECTOR_COL);
        TableSchema schema = mTable.getSchema();
        if (null == featureCols && null == vectorCol) {
            featureCols = TableUtil.getNumericCols(schema, new String[] {labelCol});
            params.set(LinearTrainParams.FEATURE_COLS, featureCols);
        }
        int labelIdx = TableUtil.findColIndexWithAssertAndHint(schema.getFieldNames(), labelCol);
        int[] featureIdx = null;
        if (featureCols != null) {
            featureIdx = new int[featureCols.length];
            for (int i = 0; i < featureCols.length; i++) {
                int idx = TableUtil.findColIndexWithAssertAndHint(mTable.getColNames(), featureCols[i]);
                featureIdx[i] = idx;
                TypeInformation<?> type = mTable.getSchema().getFieldTypes()[idx];
                AkPreconditions.checkState(TableUtil.isSupportedNumericType(type), "linear algorithm only support numerical data type. Current type is : " + type);
            }
        }
        int weightIdx = weightCol != null ? TableUtil.findColIndexWithAssertAndHint(mTable.getColNames(), weightCol) : -1;
        int vectorIdx = vectorCol != null ? TableUtil.findColIndexWithAssertAndHint(mTable.getColNames(), vectorCol) : -1;

        // Samples as (weight, raw label, features); exactly one of these lists is filled.
        List<Tuple3<Double, Object, DenseVector>> denseSamples = new ArrayList<>();
        List<Tuple3<Double, Object, SparseVector>> sparseSamples = new ArrayList<>();
        int sparseSize = -1;
        int denseSize = -1;
        boolean useSparse;
        if (featureIdx != null) {
            useSparse = false;
            denseSize = withIntercept ? featureIdx.length + 1 : featureIdx.length;
            int shift = withIntercept ? 1 : 0;
            for (Row row : mTable.getRows()) {
                Double weight = weightIdx != -1 ? ((Number) row.getField(weightIdx)).doubleValue() : 1.0d;
                Object label = row.getField(labelIdx);
                if (null == label) {
                    continue;
                }
                DenseVector features = new DenseVector(denseSize);
                if (withIntercept) {
                    features.set(0, 1.0d);
                }
                boolean complete = true;
                for (int i = 0; i < featureIdx.length; i++) {
                    Object value = row.getField(featureIdx[i]);
                    if (value == null) {
                        complete = false;
                        break;
                    }
                    features.set(i + shift, ((Number) value).doubleValue());
                }
                if (complete) {
                    denseSamples.add(Tuple3.of(weight, label, features));
                }
            }
        } else {
            List<Tuple3<Double, Object, Vector>> rawSamples = new ArrayList<>();
            for (Row row : mTable.getRows()) {
                Double weight = weightIdx != -1 ? ((Number) row.getField(weightIdx)).doubleValue() : 1.0d;
                Object label = row.getField(labelIdx);
                Vector vector = VectorUtil.getVector(row.getField(vectorIdx));
                if (null != label && null != vector) {
                    Vector features = withIntercept ? vector.prefix(1.0d) : vector;
                    rawSamples.add(Tuple3.of(weight, label, features));
                }
            }
            // Count both representations and find the dimension of each.
            int sparseCount = 0;
            int denseCount = 0;
            for (Tuple3<Double, Object, Vector> sample : rawSamples) {
                Vector vector = sample.f2;
                if (vector instanceof SparseVector) {
                    sparseCount++;
                    int[] indices = ((SparseVector) vector).getIndices();
                    if (indices.length > 0) {
                        sparseSize = Math.max(indices[indices.length - 1] + 1, sparseSize);
                    }
                } else {
                    denseCount++;
                    int size = ((DenseVector) vector).size();
                    if (denseSize < 0) {
                        denseSize = size;
                    } else {
                        AkPreconditions.checkState(size == denseSize, "Vector for linear model train have different dimension, please check your input data.");
                    }
                }
            }
            if (sparseCount > 2 * denseCount) {
                useSparse = true;
                for (Tuple3<Double, Object, Vector> sample : rawSamples) {
                    SparseVector sparse = VectorUtil.getSparseVector(sample.f2);
                    sparse.setSize(sparseSize);
                    sparseSamples.add(Tuple3.of(sample.f0, sample.f1, sparse));
                }
            } else {
                useSparse = false;
                for (Tuple3<Double, Object, Vector> sample : rawSamples) {
                    Vector vector = sample.f2;
                    if (vector instanceof DenseVector) {
                        denseSamples.add(Tuple3.of(sample.f0, sample.f1, (DenseVector) vector));
                    } else {
                        SparseVector sparse = (SparseVector) vector;
                        sparse.setSize(denseSize);
                        denseSamples.add(Tuple3.of(sample.f0, sample.f1, sparse.toDenseVector()));
                    }
                }
            }
        }

        DenseVector[] meanVar = new DenseVector[2];
        if (standardization) {
            if (useSparse) {
                // Scale-only standardization keeps sparse data sparse.
                double[] maxAbs = new double[sparseSize];
                for (Tuple3<Double, Object, SparseVector> sample : sparseSamples) {
                    int[] indices = sample.f2.getIndices();
                    double[] values = sample.f2.getValues();
                    for (int k = 0; k < indices.length; k++) {
                        maxAbs[indices[k]] = Math.max(maxAbs[indices[k]], Math.abs(values[k]));
                    }
                }
                for (int j = 0; j < sparseSize; j++) {
                    if (maxAbs[j] <= Criteria.INVALID_GAIN) {
                        maxAbs[j] = 1.0d;
                    }
                }
                for (Tuple3<Double, Object, SparseVector> sample : sparseSamples) {
                    int[] indices = sample.f2.getIndices();
                    double[] values = sample.f2.getValues(); // scaled in place
                    for (int k = 0; k < indices.length; k++) {
                        values[k] = values[k] / maxAbs[indices[k]];
                    }
                }
                meanVar[0] = new DenseVector(sparseSize);
                meanVar[1] = new DenseVector(maxAbs);
            } else {
                double[] mean = new double[denseSize];
                double[] std = new double[denseSize];
                for (Tuple3<Double, Object, DenseVector> sample : denseSamples) {
                    double[] data = sample.f2.getData();
                    for (int j = 0; j < denseSize; j++) {
                        mean[j] = mean[j] + data[j];
                        std[j] = std[j] + data[j] * data[j];
                    }
                }
                int n = denseSamples.size();
                for (int j = 0; j < denseSize; j++) {
                    mean[j] = mean[j] / n;
                    std[j] = n <= 1 ? 1.0d : Math.sqrt(Math.max(Criteria.INVALID_GAIN, (std[j] - n * mean[j] * mean[j]) / (n - 1)));
                    if (Criteria.INVALID_GAIN == std[j]) {
                        // Constant feature: divide by its value instead of centering it away.
                        std[j] = Criteria.INVALID_GAIN == mean[j] ? 1.0d : mean[j];
                        mean[j] = 0.0d;
                    }
                }
                for (Tuple3<Double, Object, DenseVector> sample : denseSamples) {
                    double[] data = sample.f2.getData(); // standardized in place
                    for (int j = 0; j < denseSize; j++) {
                        data[j] = (data[j] - mean[j]) / std[j];
                    }
                }
                meanVar[0] = new DenseVector(mean);
                meanVar[1] = new DenseVector(std);
            }
        }

        HashSet<Object> distinctLabels = new HashSet<>();
        HashMap<Object, Double> labelToValue = new HashMap<>();
        Object[] labelValues = new Object[0];
        if (!z) {
            if (useSparse) {
                for (Tuple3<Double, Object, SparseVector> sample : sparseSamples) {
                    distinctLabels.add(sample.f1);
                }
            } else {
                for (Tuple3<Double, Object, DenseVector> sample : denseSamples) {
                    distinctLabels.add(sample.f1);
                }
            }
            labelValues = orderLabels(distinctLabels);
            if (LinearModelType.Softmax == linearModelType) {
                for (int i = 0; i < labelValues.length; i++) {
                    labelToValue.put(labelValues[i], (double) i);
                }
            } else {
                if (labelValues.length != 2) {
                    // List at most 10 labels, comma-separated.
                    StringBuilder sb = new StringBuilder();
                    for (int i = 0; i < Math.min(labelValues.length, 10); i++) {
                        if (i > 0) {
                            sb.append(",");
                        }
                        sb.append(labelValues[i]);
                    }
                    if (labelValues.length > 10) {
                        sb.append(", ...... ");
                    }
                    throw new AkIllegalDataException(linearModelType + " need 2 label values, but training data's distinct label values : " + sb);
                }
                labelToValue.put(labelValues[0], 1.0d);
                labelToValue.put(labelValues[1], -1.0d);
            }
        }
        // Guard against a label column that looks continuous (e.g. a wrongly chosen column).
        if (distinctLabels.size() > mTable.getNumRow() * 0.5d && distinctLabels.size() > 1000) {
            throw new AkIllegalDataException("label num is : " + distinctLabels.size() + ", sample num is : " + mTable.getNumRow() + ", please check your label column.");
        }

        List<Tuple3<Double, Double, Vector>> samples = new ArrayList<>();
        if (useSparse) {
            for (Tuple3<Double, Object, SparseVector> sample : sparseSamples) {
                Double label = z ? Double.valueOf(Double.parseDouble(sample.f1.toString())) : labelToValue.get(sample.f1);
                Vector features = sample.f2;
                samples.add(Tuple3.of(sample.f0, label, features));
            }
            return Tuple4.of(samples, meanVar, sparseSize, labelValues);
        }
        for (Tuple3<Double, Object, DenseVector> sample : denseSamples) {
            Double label = z ? Double.valueOf(Double.parseDouble(sample.f1.toString())) : labelToValue.get(sample.f1);
            Vector features = sample.f2;
            samples.add(Tuple3.of(sample.f0, label, features));
        }
        return Tuple4.of(samples, meanVar, denseSize, labelValues);
    }

    /**
     * Resolves the feature type names of {@code strArr} against the schema of {@code batchOperator}.
     * This overload fetches the operator's schema and then delegates to
     * {@code getFeatureTypes(TableSchema, String[])}.
     *
     * @param batchOperator operator whose output schema holds the feature columns
     * @param strArr        feature column names; may be {@code null}
     * @return the per-column type names, or {@code null} when {@code strArr} is {@code null}
     */
    protected static String[] getFeatureTypes(BatchOperator<?> batchOperator, String[] strArr) {
        final TableSchema schema = batchOperator.getSchema();
        return getFeatureTypes(schema, strArr);
    }

    /**
     * Maps each feature column to a short type name understood by the linear model.
     * The supported names are "double", "float", "long", "int", "short" and "bool".
     *
     * @param tableSchema schema in which the columns are looked up
     * @param strArr      feature column names; may be {@code null}
     * @return one type name per column, in the same order as {@code strArr}, or {@code null} when
     *         {@code strArr} is {@code null}
     * @throws AkIllegalArgumentException if a column has a type outside the supported set
     *         (a missing column is reported by {@code TableUtil.findColIndexWithAssertAndHint})
     */
    protected static String[] getFeatureTypes(TableSchema tableSchema, String[] strArr) {
        if (strArr == null) {
            return null;
        }
        final String[] typeNames = new String[strArr.length];
        int pos = 0;
        for (String colName : strArr) {
            final TypeInformation<?> colType =
                tableSchema.getFieldTypes()[TableUtil.findColIndexWithAssertAndHint(tableSchema.getFieldNames(), colName)];
            final String typeName;
            if (colType.equals(Types.DOUBLE)) {
                typeName = "double";
            } else if (colType.equals(Types.FLOAT)) {
                typeName = "float";
            } else if (colType.equals(Types.LONG)) {
                typeName = "long";
            } else if (colType.equals(Types.INT)) {
                typeName = "int";
            } else if (colType.equals(Types.SHORT)) {
                typeName = "short";
            } else if (colType.equals(Types.BOOLEAN)) {
                typeName = "bool";
            } else {
                throw new AkIllegalArgumentException("Linear algorithm only support numerical data type. Current type is : " + colType);
            }
            typeNames[pos++] = typeName;
        }
        return typeNames;
    }

    /**
     * Normalizes the regularization parameters of regression-type linear models and reports whether
     * the model is a regression.
     *
     * <ul>
     *   <li>{@link LinearModelType#LinearReg} named "Ridge Regression": {@code lambda} (must be
     *       {@code >= 0}) is moved to {@code HasL2.L_2}.</li>
     *   <li>{@link LinearModelType#LinearReg} named "LASSO": {@code lambda} (must be {@code >= 0}) is
     *       moved to {@code HasL1.L_1}.</li>
     *   <li>Any other {@link LinearModelType#LinearReg}: {@code params} is left untouched.</li>
     *   <li>{@link LinearModelType#SVR}: requires {@code tau >= 0} and {@code C > 0}, then replaces
     *       {@code C} by the L2 weight {@code 1 / C}.</li>
     * </ul>
     *
     * @param params          training parameters; mutated in place as described above
     * @param linearModelType type of the linear model being trained
     * @param str             model name, used to tell ridge and LASSO regression apart
     * @return {@code true} for LinearReg and SVR, {@code false} for every other model type
     * @throws AkIllegalArgumentException if tau, C or the LASSO lambda is out of range
     */
    private static boolean getIsRegProc(Params params, LinearModelType linearModelType, String str) {
        if (linearModelType.equals(LinearModelType.LinearReg)) {
            if ("Ridge Regression".equals(str)) {
                final double lambda = (Double) params.get(RidgeRegTrainParams.LAMBDA);
                // Zero is accepted, as the message states and as the LASSO branch already allows
                // (the previous check `lambda > 0` contradicted its own message).
                AkPreconditions.checkState(lambda >= 0.0, "Lambda must be positive number or zero! lambda is : " + lambda);
                params.set(HasL2.L_2, lambda);
                params.remove(RidgeRegTrainParams.LAMBDA);
            } else if ("LASSO".equals(str)) {
                final double lambda = (Double) params.get(LassoRegTrainParams.LAMBDA);
                if (lambda < 0.0) {
                    throw new AkIllegalArgumentException("Lambda must be positive number or zero!");
                }
                params.set(HasL1.L_1, lambda);
                // Remove the LASSO parameter itself (previously removed the ridge constant by mistake).
                params.remove(LassoRegTrainParams.LAMBDA);
            }
            return true;
        }
        if (linearModelType.equals(LinearModelType.SVR)) {
            final double tau = (Double) params.get(LinearSvrTrainParams.TAU);
            final double c = (Double) params.get(LinearSvrTrainParams.C);
            if (tau < 0.0) {
                throw new AkIllegalArgumentException("Parameter tau must be positive number or zero!");
            }
            if (c <= 0.0) {
                throw new AkIllegalArgumentException("Parameter C must be positive number!");
            }
            // SVR's C is expressed as an equivalent L2 regularization weight 1 / C.
            params.set(HasL2.L_2, 1.0 / c);
            params.remove(LinearSvrTrainParams.C);
            return true;
        }
        return false;
    }

    /**
     * Builds a {@link LinearModelData} from trained coefficients. Where applicable, it first maps the
     * coefficients back to the un-standardized feature scale.
     *
     * <p>Side effects:
     * <ul>
     *   <li>sets {@code ModelParamName.VECTOR_SIZE} on {@code params};</li>
     *   <li>may rewrite entries of {@code denseVectorArr} through {@code modifyMeanVar};</li>
     *   <li>rewrites {@code tuple2.f0} in place.</li>
     * </ul>
     * AFT models skip all rescaling.
     *
     * @param params          training parameters; read for the model type, class count, intercept flag,
     *                        label column and feature types
     * @param strArr          feature column names stored in the model
     * @param typeInformation label type stored in the model
     * @param denseVectorArr  two vectors; coefficients are divided by entries of index 1 and the
     *                        intercept is shifted by entries of index 0 (presumably feature mean and
     *                        scale — TODO confirm)
     * @param z               when set, position 0 of each coefficient block absorbs the mean shift, i.e.
     *                        it is treated as the intercept (appears to mean "has intercept")
     * @param z2              standardization flag; forwarded to {@code modifyMeanVar} and gates the
     *                        rescaling of non-Softmax models
     * @param tuple2          {@code f0} is the coefficient vector; for Softmax it holds
     *                        {@code numClasses - 1} consecutive blocks
     * @return the model data, built around the same (now rescaled) coefficient vector
     */
    public static LinearModelData buildLinearModelData(Params params, String[] strArr, TypeInformation<?> typeInformation, DenseVector[] denseVectorArr, boolean z, boolean z2, Tuple2<DenseVector, double[]> tuple2) {
        final LinearModelType modelType = (LinearModelType) params.get(ModelParamName.LINEAR_MODEL_TYPE);
        final boolean isSoftmax = LinearModelType.Softmax == modelType;
        final boolean isAft = LinearModelType.AFT == modelType;
        // Number of coefficient blocks packed into tuple2.f0.
        final int numBlocks = isSoftmax ? ((Integer) params.get(ModelParamName.NUM_CLASSES)).intValue() - 1 : 1;

        if (!isAft) {
            modifyMeanVar(z2, denseVectorArr);
        }

        final DenseVector coef = tuple2.f0;
        final int blockLength = coef.size() / numBlocks;
        final int interceptSlots = ((Boolean) params.get(ModelParamName.HAS_INTERCEPT_ITEM)).booleanValue() ? 1 : 0;
        final int aftSlots = isAft ? 1 : 0;
        params.set(ModelParamName.VECTOR_SIZE, Integer.valueOf(blockLength - interceptSlots - aftSlots));

        if (isSoftmax) {
            if (z) {
                final int stride = denseVectorArr[0].size();
                for (int block = 0; block < numBlocks; block++) {
                    final int offset = block * stride;
                    double interceptShift = 0.0d;
                    for (int j = 1; j < stride; j++) {
                        final double w = coef.get(offset + j);
                        interceptShift += (w * denseVectorArr[0].get(j)) / denseVectorArr[1].get(j);
                        coef.set(offset + j, w / denseVectorArr[1].get(j));
                    }
                    coef.set(offset, coef.get(offset) - interceptShift);
                }
            } else {
                // No intercept: the scale vector is cycled across all class blocks.
                for (int k = 0; k < coef.size(); k++) {
                    final DenseVector scale = denseVectorArr[1];
                    coef.set(k, coef.get(k) / scale.get(k % scale.size()));
                }
            }
        } else if (!isAft && z2) {
            final int len = denseVectorArr[0].size();
            if (z) {
                double interceptShift = 0.0d;
                for (int j = 1; j < len; j++) {
                    final double w = coef.get(j);
                    interceptShift += (w * denseVectorArr[0].get(j)) / denseVectorArr[1].get(j);
                    coef.set(j, w / denseVectorArr[1].get(j));
                }
                coef.set(0, coef.get(0) - interceptShift);
            } else {
                for (int j = 0; j < len; j++) {
                    coef.set(j, coef.get(j) / denseVectorArr[1].get(j));
                }
            }
        }

        final LinearModelData modelData = new LinearModelData(typeInformation, params, strArr, coef);
        modelData.labelName = (String) params.get(LinearTrainParams.LABEL_COL);
        modelData.featureTypes = (String[]) params.get(ModelParamName.FEATURE_TYPES);
        return modelData;
    }

    /**
     * Neutralizes degenerate scale entries before coefficients are rescaled.
     *
     * <p>When {@code z} is {@code true}, every position whose entry in {@code denseVectorArr[1]} equals
     * {@code Criteria.INVALID_GAIN} is modified: that entry becomes 1.0 and the matching entry of
     * {@code denseVectorArr[0]} becomes {@code Criteria.INVALID_GAIN}. When {@code z} is {@code false},
     * nothing is touched.
     *
     * <p>NOTE(review): {@code Criteria.INVALID_GAIN} looks like a decompiler substitution for the
     * literal 0.0 (i.e. "zero scale"). Confirm.
     *
     * @param z              standardization flag
     * @param denseVectorArr two vectors, modified in place (presumably mean at 0, scale at 1)
     */
    private static void modifyMeanVar(boolean z, DenseVector[] denseVectorArr) {
        if (!z) {
            return;
        }
        final DenseVector scale = denseVectorArr[1];
        final DenseVector center = denseVectorArr[0];
        for (int idx = 0; idx < scale.size(); idx++) {
            if (scale.get(idx) != Criteria.INVALID_GAIN) {
                continue;
            }
            scale.set(idx, 1.0d);
            center.set(idx, Criteria.INVALID_GAIN);
        }
    }

    /**
     * Wraps the given train-info rows into a {@link LinearModelTrainInfo}.
     *
     * @param list rows describing the training run
     * @return a new {@link LinearModelTrainInfo} built from {@code list}
     */
    @Override // com.alibaba.alink.common.lazy.WithTrainInfoLocalOp
    public LinearModelTrainInfo createTrainInfo(List<Row> list) {
        final LinearModelTrainInfo trainInfo = new LinearModelTrainInfo(list);
        return trainInfo;
    }

    /**
     * Returns the training-info side output, which this operator exposes at side-output index 0.
     *
     * @return side output number 0 of this operator
     */
    @Override // com.alibaba.alink.common.lazy.WithTrainInfoLocalOp
    public LocalOperator<?> getSideOutputTrainInfo() {
        final int trainInfoIndex = 0;
        return this.getSideOutput(trainInfoIndex);
    }

    /**
     * Raw-array bridge that forwards to the generic {@code linkFrom(LocalOperator<?>[])} overload.
     *
     * <p>NOTE(review): this mirrors a compiler-generated bridge reproduced by the decompiler. In
     * hand-written source, javac emits it itself, and an explicit copy likely clashes by erasure with
     * the generic overload. Confirm and drop it when recovering real source.
     */
    @Override // com.alibaba.alink.operator.local.LocalOperator
    public /* bridge */ /* synthetic */ LocalOperator linkFrom(LocalOperator[] localOperatorArr) {
        final LocalOperator<?>[] inputs = localOperatorArr;
        return linkFrom(inputs);
    }

    /**
     * Raw-{@code List} bridge that forwards to {@code createTrainInfo(List<Row>)}.
     *
     * <p>NOTE(review): this is a decompiler-reproduced synthetic bridge. javac generates it for the
     * generic override, and an explicit copy likely clashes by erasure. Confirm.
     */
    @Override // com.alibaba.alink.common.lazy.WithTrainInfoLocalOp
    public /* bridge */ /* synthetic */ LinearModelTrainInfo createTrainInfo(List list) {
        @SuppressWarnings("unchecked")
        final List<Row> rows = (List<Row>) list;
        return createTrainInfo(rows);
    }
}
