package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.finance.stepwiseSelector.BaseStepWiseSelectorBatchOp;
import com.alibaba.alink.operator.common.finance.stepwiseSelector.ClassificationSelectorResult;
import com.alibaba.alink.operator.common.finance.stepwiseSelector.RegressionSelectorResult;
import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorResult;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer;
import com.alibaba.alink.params.finance.HasConstrainedLinearModelType;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Locale;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/ModelSummaryHelper.class */
public class ModelSummaryHelper {

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/ModelSummaryHelper$CalRegSummary.class */
    public static class CalRegSummary extends RichFlatMapFunction<BaseVectorSummarizer, LinearRegressionSummary> {
        private static final long serialVersionUID = 1372774780273725623L;
        private LinearModelData modelData;

        public void open(Configuration configuration) {
            this.modelData = (LinearModelData) getRuntimeContext().getBroadcastVariable("linearModelData").get(0);
        }

        public void flatMap(BaseVectorSummarizer baseVectorSummarizer, Collector<LinearRegressionSummary> collector) throws Exception {
            DenseVector denseVector = this.modelData.coefVector;
            int[] iArr = new int[baseVectorSummarizer.toSummary().vectorSize() - 1];
            for (int i = 0; i < iArr.length; i++) {
                iArr[i] = i + 1;
            }
            collector.collect(LocalLinearModel.calcLinearRegressionSummary(denseVector, baseVectorSummarizer, 0, iArr));
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((BaseVectorSummarizer) obj, (Collector<LinearRegressionSummary>) collector);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/ModelSummaryHelper$CalcGradientAndHessian.class */
    public static class CalcGradientAndHessian extends RichMapPartitionFunction<Tuple3<Double, Double, Vector>, Tuple4<DenseVector, DenseVector, DenseMatrix, Double>> {
        private static final long serialVersionUID = 2861532185191763285L;
        private LinearModelData modelData;

        public void open(Configuration configuration) {
            this.modelData = (LinearModelData) getRuntimeContext().getBroadcastVariable("linearModelData").get(0);
        }

        public void mapPartition(Iterable<Tuple3<Double, Double, Vector>> iterable, Collector<Tuple4<DenseVector, DenseVector, DenseMatrix, Double>> collector) throws Exception {
            OptimObjFunc objFunction = OptimObjFunc.getObjFunction(LinearModelType.LR, new Params());
            int size = this.modelData.coefVector.size();
            DenseVector denseVector = this.modelData.coefVector;
            DenseMatrix denseMatrix = new DenseMatrix(size, size);
            DenseVector denseVector2 = new DenseVector(size);
            collector.collect(Tuple4.of(denseVector, denseVector2, denseMatrix, objFunction.calcHessianGradientLoss(iterable, denseVector, denseMatrix, denseVector2).f1));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/ModelSummaryHelper$CalcLrSummary.class */
    public static class CalcLrSummary extends RichMapPartitionFunction<Tuple4<DenseVector, DenseVector, DenseMatrix, Double>, LogistRegressionSummary> {
        private static final long serialVersionUID = 1799381476070262842L;
        private BaseVectorSummarizer srt;

        public void open(Configuration configuration) {
            this.srt = (BaseVectorSummarizer) getRuntimeContext().getBroadcastVariable("Summarizer").get(0);
        }

        public void mapPartition(Iterable<Tuple4<DenseVector, DenseVector, DenseMatrix, Double>> iterable, Collector<LogistRegressionSummary> collector) throws Exception {
            Iterator<Tuple4<DenseVector, DenseVector, DenseMatrix, Double>> it = iterable.iterator();
            if (it.hasNext()) {
                collector.collect(LocalLinearModel.calcLrSummary(it.next(), this.srt));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/ModelSummaryHelper$TransformLrLabel.class */
    public static class TransformLrLabel extends RichMapFunction<Row, Tuple3<Double, Double, Vector>> {
        private static final long serialVersionUID = -1178726287840080632L;
        private LinearModelData modelData;
        private String positiveLableValueString;
        private int vecIdx;
        private int[] featureIndices;
        private int labelIdx;
        private int weightIdx;

        public TransformLrLabel(int i, int[] iArr, int i2, int i3) {
            this.vecIdx = i;
            this.featureIndices = iArr;
            this.labelIdx = i2;
            this.weightIdx = i3;
        }

        public void open(Configuration configuration) {
            this.modelData = (LinearModelData) getRuntimeContext().getBroadcastVariable("linearModelData").get(0);
            this.positiveLableValueString = this.modelData.labelValues[0].toString();
        }

        public Tuple3<Double, Double, Vector> map(Row row) throws Exception {
            return ModelSummaryHelper.transferLabel(row, this.featureIndices, this.vecIdx, this.labelIdx, this.weightIdx, this.positiveLableValueString);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/linear/ModelSummaryHelper$TransformLrLabelWithLabel.class */
    public static class TransformLrLabelWithLabel extends RichMapFunction<Row, Row> {
        private static final long serialVersionUID = 9009298353079015437L;
        private String positiveLableValueString;
        private int labelIdx;

        public TransformLrLabelWithLabel(int i, String str) {
            this.positiveLableValueString = null;
            this.labelIdx = i;
            this.positiveLableValueString = str;
        }

        public void open(Configuration configuration) throws Exception {
            this.positiveLableValueString = ModelSummaryHelper.orderLabels(getRuntimeContext().getBroadcastVariable("labels"), this.positiveLableValueString)[0].toString();
        }

        public Row map(Row row) throws Exception {
            row.setField(this.labelIdx, Double.valueOf(FeatureLabelUtil.getLabelValue(row, false, this.labelIdx, this.positiveLableValueString)));
            return row;
        }
    }

    /**
     * Computes the model summary and wraps it in a {@link SelectorResult}.
     *
     * <p>For {@code LR} the summary comes from {@link #calBinarySummary} and is wrapped in a
     * {@link ClassificationSelectorResult}; for every other model type it comes from
     * {@link #calRegSummary} and is wrapped in a {@link RegressionSelectorResult}.
     *
     * @param batchOperator input data
     * @param linearModelType model type that selects the summary kind
     * @param dataSet trained model data
     * @param str vector column name, forwarded to the summary method
     * @param strArr selected feature column names; also stored as {@code selectedCols}
     * @param str2 label column name, forwarded to the summary method
     * @return data set of selector results
     */
    public static DataSet<SelectorResult> calModelSummary(BatchOperator batchOperator, HasConstrainedLinearModelType.LinearModelType linearModelType, DataSet<LinearModelData> dataSet, String str, final String[] strArr, String str2) {
        if (HasConstrainedLinearModelType.LinearModelType.LR == linearModelType) {
            DataSet<LogistRegressionSummary> lrSummary = calBinarySummary(batchOperator, dataSet, str, strArr, str2);
            return lrSummary.map(new MapFunction<LogistRegressionSummary, SelectorResult>() {
                private static final long serialVersionUID = -3321750369444781812L;

                @Override
                public SelectorResult map(LogistRegressionSummary logistRegressionSummary) throws Exception {
                    ClassificationSelectorResult wrapped = new ClassificationSelectorResult();
                    wrapped.modelSummary = logistRegressionSummary;
                    wrapped.selectedCols = strArr;
                    return wrapped;
                }
            });
        }

        DataSet<LinearRegressionSummary> regSummary = calRegSummary(batchOperator, dataSet, str, strArr, str2);
        return regSummary.map(new MapFunction<LinearRegressionSummary, SelectorResult>() {
            private static final long serialVersionUID = 2198755005878386785L;

            @Override
            public SelectorResult map(LinearRegressionSummary linearRegressionSummary) throws Exception {
                RegressionSelectorResult wrapped = new RegressionSelectorResult();
                wrapped.modelSummary = linearRegressionSummary;
                wrapped.selectedCols = strArr;
                return wrapped;
            }
        });
    }

    /**
     * Builds the linear regression summary for a trained model.
     *
     * <p>A vector summarizer is built from the vector column when {@code str} is non-empty, and
     * from the label column followed by the feature columns when {@code strArr} is non-empty.
     * When both are given, the column-based summarizer is the one used.
     *
     * @param batchOperator input data
     * @param dataSet trained model data, broadcast as {@code "linearModelData"}
     * @param str vector column name, may be {@code null} or empty
     * @param strArr feature column names, may be {@code null} or empty
     * @param str2 label column name
     * @return data set holding the regression summary
     * @throws InvalidParameterException if neither {@code str} nor {@code strArr} is set
     */
    public static DataSet<LinearRegressionSummary> calRegSummary(BatchOperator batchOperator, DataSet<LinearModelData> dataSet, String str, String[] strArr, String str2) {
        DataSet<BaseVectorSummarizer> summarizer = null;

        boolean hasVectorCol = str != null && !str.isEmpty();
        if (hasVectorCol) {
            DataSet rows = batchOperator.getDataSet();
            int vectorIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str);
            int labelIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str2);
            summarizer = StatisticsHelper.summarizer(
                rows.map(new BaseStepWiseSelectorBatchOp.ToVectorWithReservedCols(vectorIdx, labelIdx)), true);
        }

        boolean hasFeatureCols = strArr != null && strArr.length > 0;
        if (hasFeatureCols) {
            // Label goes first so that it ends up at index 0 of the assembled vector.
            String[] labelThenFeatures = new String[strArr.length + 1];
            labelThenFeatures[0] = str2;
            System.arraycopy(strArr, 0, labelThenFeatures, 1, strArr.length);
            summarizer = StatisticsHelper.summarizer(
                StatisticsHelper.transformToVector(batchOperator, labelThenFeatures, null), true);
        }

        if (summarizer == null) {
            throw new InvalidParameterException("select col and select cols must be set one");
        }
        return summarizer.flatMap(new CalRegSummary()).withBroadcastSet(dataSet, "linearModelData");
    }

    /**
     * Builds the logistic regression summary for a trained binary model.
     *
     * <p>Stages: rows are converted by {@link TransformLrLabel}; {@link CalcGradientAndHessian}
     * emits one tuple per partition; the tuples are reduced into one; {@link CalcLrSummary}
     * turns it into the summary using a vector summarizer broadcast as {@code "Summarizer"}.
     *
     * @param batchOperator input data
     * @param dataSet trained model data, broadcast as {@code "linearModelData"}
     * @param str vector column name, or {@code null} when feature columns are used
     * @param strArr feature column names, may be {@code null}; each must have a supported numeric type
     * @param str2 label column name
     * @return data set holding the logistic regression summary
     */
    public static DataSet<LogistRegressionSummary> calBinarySummary(BatchOperator batchOperator, DataSet<LinearModelData> dataSet, String str, String[] strArr, String str2) {
        final int labelIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str2);

        int[] featureIdx = null;
        if (strArr != null) {
            featureIdx = new int[strArr.length];
            int pos = 0;
            for (String featureCol : strArr) {
                int colIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), featureCol);
                featureIdx[pos++] = colIdx;
                TypeInformation colType = batchOperator.getSchema().getFieldTypes()[colIdx];
                Preconditions.checkState(TableUtil.isSupportedNumericType(colType), "linear algorithm only support numerical data type. type is : " + colType);
            }
        }

        DataSet inputRows = batchOperator.getDataSet();
        int vecIdx = -1;
        if (str != null) {
            vecIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str);
        }
        // Weight index is fixed to -1 here.
        Operator transferred = inputRows
            .map(new TransformLrLabel(vecIdx, featureIdx, labelIdx, -1))
            .withBroadcastSet(dataSet, "linearModelData")
            .name("TransferLrData");

        Operator combined = transferred
            .mapPartition(new CalcGradientAndHessian())
            .withBroadcastSet(dataSet, "linearModelData")
            .reduce(new ReduceFunction<Tuple4<DenseVector, DenseVector, DenseMatrix, Double>>() {
                private static final long serialVersionUID = -9187304403661961376L;

                @Override
                public Tuple4<DenseVector, DenseVector, DenseMatrix, Double> reduce(Tuple4<DenseVector, DenseVector, DenseMatrix, Double> tuple4, Tuple4<DenseVector, DenseVector, DenseMatrix, Double> tuple42) throws Exception {
                    // f0 is taken from the left operand; f1, f2 and f3 are summed.
                    DenseVector vectorSum = tuple4.f1.plus((Vector) tuple42.f1);
                    DenseMatrix matrixSum = tuple4.f2.plus(tuple42.f2);
                    double scalarSum = tuple4.f3.doubleValue() + tuple42.f3.doubleValue();
                    return Tuple4.of(tuple4.f0, vectorSum, matrixSum, Double.valueOf(scalarSum));
                }
            })
            .name("combine gradient and hessian");

        SingleInputUdfOperator lrSummary = combined.mapPartition(new CalcLrSummary());

        DataSet summarizer = StatisticsHelper.summarizer(transferred.map(new MapFunction<Tuple3<Double, Double, Vector>, Vector>() {
            private static final long serialVersionUID = -1205844850032698897L;

            @Override
            public Vector map(Tuple3<Double, Double, Vector> tuple3) throws Exception {
                // Overwrites element 0 of the vector in place with the tuple's f0 value.
                Vector vector = tuple3.f2;
                vector.set(0, tuple3.f0.doubleValue());
                return vector;
            }
        }), false);

        return lrSummary.withBroadcastSet(summarizer, "Summarizer");
    }

    /**
     * Rewrites the label column of the input as a double column.
     *
     * @param batchOperator input data
     * @param str label column name
     * @param str2 requested positive label value, may be {@code null}
     * @param l first argument handed to {@code DataSetConversionUtil.toTable}
     *     (presumably the ML environment id; confirm against that helper)
     * @return the transformed operator together with the data set of distinct label values
     */
    public static Tuple2<BatchOperator, DataSet<Object>> transformLrLabel(BatchOperator batchOperator, String str, String str2, Long l) {
        final int labelIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str);
        final DataSet<Object> labelInfo = getLabelInfo(batchOperator, str);

        SingleInputUdfOperator relabeled = batchOperator.getDataSet()
            .map(new TransformLrLabelWithLabel(labelIdx, str2))
            .withBroadcastSet(labelInfo, "labels");

        // The label column now holds doubles, so its declared type is switched accordingly.
        TypeInformation<?>[] outputTypes = batchOperator.getColTypes();
        outputTypes[labelIdx] = Types.DOUBLE;

        BatchOperator transformed = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(l, (DataSet<Row>) relabeled, batchOperator.getColNames(), outputTypes));
        return Tuple2.of(transformed, labelInfo);
    }

    /**
     * Collects the two labels of a binary problem and puts the positive one first.
     *
     * <p>When {@code str} is {@code null}, the label whose string form compares greater is used
     * as the positive label. The two labels are swapped only when the second label's string
     * form equals the positive label; otherwise the incoming order is kept.
     *
     * @param iterable the distinct label values; must contain exactly two elements
     * @param str requested positive label value, may be {@code null}
     * @return a new two-element array of labels
     */
    public static Object[] orderLabels(Iterable<Object> iterable, String str) {
        ArrayList<Object> collected = new ArrayList<>();
        for (Object label : iterable) {
            collected.add(label);
        }
        Object[] labels = collected.toArray(new Object[0]);
        Preconditions.checkState(labels.length == 2, "labels count should be 2 in 2 classification algo.");

        String first = labels[0].toString();
        String second = labels[1].toString();

        String positive;
        if (str != null) {
            positive = str;
        } else if (second.compareTo(first) > 0) {
            positive = second;
        } else {
            positive = first;
        }

        if (labels[1].toString().equals(positive)) {
            Object swapped = labels[0];
            labels[0] = labels[1];
            labels[1] = swapped;
        }
        return labels;
    }

    /**
     * Tells whether a model type name denotes linear regression.
     *
     * <p>The name is trimmed and upper-cased with {@link Locale#ROOT}, so the result does not
     * depend on the JVM default locale (under a Turkish locale, for example, the default
     * upper-casing maps {@code i} to a dotted capital I and the name would not match).
     *
     * @param str model type name; {@code "LINEARREG"} or {@code "LR"}, case-insensitive,
     *     surrounding whitespace ignored
     * @return {@code true} for {@code "LINEARREG"}, {@code false} for {@code "LR"}
     * @throws RuntimeException if the name is neither of the two supported values
     */
    public static Boolean isLinearRegression(String str) {
        String upperCase = str.trim().toUpperCase(Locale.ROOT);
        if ("LINEARREG".equals(upperCase)) {
            return Boolean.TRUE;
        }
        if ("LR".equals(upperCase)) {
            return Boolean.FALSE;
        }
        throw new RuntimeException("model type not support. " + upperCase);
    }

    /**
     * Returns the distinct values of one column as a data set of plain objects.
     *
     * @param batchOperator input data
     * @param str name of the column to read
     * @return data set with the first field of each distinct selected row
     */
    public static DataSet<Object> getLabelInfo(BatchOperator batchOperator, String str) {
        String[] selectedCols = new String[]{str};
        BatchOperator distinctLabels = batchOperator.select(selectedCols).distinct();
        return distinctLabels.getDataSet().map(new MapFunction<Row, Object>() {
            private static final long serialVersionUID = 2044498497762182626L;

            @Override
            public Object map(Row row) {
                // Only one column was selected, so the value sits in field 0.
                return row.getField(0);
            }
        });
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Converts one input row into a {@code (weight, label, features)} tuple.
     *
     * <p>The features are read from the columns in {@code iArr} when it is non-null, otherwise
     * from the vector stored in column {@code i}. In both cases the returned vector is the
     * feature vector with {@code 1.0} prepended via {@code Vector.prefix}.
     *
     * @param row input row
     * @param iArr indices of the numeric feature columns, or {@code null} to use the vector column
     * @param i index of the vector column; only read when {@code iArr} is {@code null}
     * @param i2 index of the label column, passed to {@code FeatureLabelUtil.getLabelValue}
     * @param i3 index of the weight column, or {@code -1} to use a weight of {@code 1.0}
     * @param str positive label value, passed to {@code FeatureLabelUtil.getLabelValue}
     * @return tuple of weight, label value and prefixed feature vector
     * @throws Exception if the label or vector cannot be read; a {@code null} vector fails the
     *     {@code Preconditions.checkState} check
     */
    public static Tuple3<Double, Double, Vector> transferLabel(Row row, int[] iArr, int i, int i2, int i3, String str) throws Exception {
        Double weight = Double.valueOf(i3 != -1 ? ((Number) row.getField(i3)).doubleValue() : 1.0d);
        Double label = Double.valueOf(FeatureLabelUtil.getLabelValue(row, false, i2, str));

        // Typed local instead of a raw intermediate Tuple3, so the result keeps its generic types.
        Vector features;
        if (iArr != null) {
            DenseVector dense = new DenseVector(iArr.length);
            for (int k = 0; k < iArr.length; k++) {
                dense.set(k, ((Number) row.getField(iArr[k])).doubleValue());
            }
            features = dense;
        } else {
            features = VectorUtil.getVector(row.getField(i));
            Preconditions.checkState(features != null, "vector for linear model train is null, please check your input data.");
        }
        return Tuple3.of(weight, label, features.prefix(1.0d));
    }
}
