package com.alibaba.alink.operator.batch.statistics.utils;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult;
import com.alibaba.alink.operator.common.statistics.basicstatistic.DenseVectorSummarizer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummarizer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary;
import com.alibaba.alink.operator.common.statistics.basicstatistic.VectorSummarizerUtil;
import com.alibaba.alink.operator.common.statistics.statistics.SrtUtil;
import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable;
import com.alibaba.alink.operator.common.statistics.statistics.WindowTable;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.params.statistics.HasStatLevel_L1;
import java.util.Iterator;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper.class */
public class StatisticsHelper {

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$ColsToDoubleColsMap.class */
    private static class ColsToDoubleColsMap implements MapFunction<Row, Row> {
        private static final long serialVersionUID = 2021889928304298454L;
        private final int[] selectedColIndices;
        private final int[] reservedColIndices;

        ColsToDoubleColsMap(int[] iArr, int[] iArr2) {
            this.selectedColIndices = iArr;
            this.reservedColIndices = null == iArr2 ? new int[0] : iArr2;
        }

        public Row map(Row row) throws Exception {
            Row row2 = new Row(this.selectedColIndices.length + this.reservedColIndices.length);
            for (int i = 0; i < this.selectedColIndices.length; i++) {
                row2.setField(i, Double.valueOf(((Number) row.getField(this.selectedColIndices[i])).doubleValue()));
            }
            for (int i2 = 0; i2 < this.reservedColIndices.length; i2++) {
                row2.setField(i2 + this.selectedColIndices.length, row.getField(this.reservedColIndices[i2]));
            }
            return row2;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$ColsToVectorWithReservedColsMap.class */
    public static class ColsToVectorWithReservedColsMap implements MapFunction<Row, Tuple2<Vector, Row>> {
        private static final long serialVersionUID = -7292044433828115396L;
        private final int[] selectedColIndices;
        private final int[] reservedColIndices;

        ColsToVectorWithReservedColsMap(int[] iArr, int[] iArr2) {
            this.selectedColIndices = iArr;
            this.reservedColIndices = iArr2;
        }

        public Tuple2<Vector, Row> map(Row row) throws Exception {
            DenseVector denseVector = new DenseVector(this.selectedColIndices.length);
            for (int i = 0; i < this.selectedColIndices.length; i++) {
                denseVector.set(i, ((Number) row.getField(this.selectedColIndices[i])).doubleValue());
            }
            Row row2 = new Row(this.reservedColIndices.length);
            for (int i2 = 0; i2 < this.reservedColIndices.length; i2++) {
                row2.setField(i2, row.getField(this.reservedColIndices[i2]));
            }
            return new Tuple2<>(denseVector, row2);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$ColsToVectorWithoutReservedColsMap.class */
    public static class ColsToVectorWithoutReservedColsMap implements MapFunction<Row, Vector> {
        private static final long serialVersionUID = -8479361651447801687L;
        private final int[] selectedColIndices;

        ColsToVectorWithoutReservedColsMap(int[] iArr) {
            this.selectedColIndices = iArr;
        }

        public Vector map(Row row) throws Exception {
            DenseVector denseVector = new DenseVector(this.selectedColIndices.length);
            double[] data = denseVector.getData();
            for (int i = 0; i < this.selectedColIndices.length; i++) {
                Object field = row.getField(this.selectedColIndices[i]);
                if (field instanceof Number) {
                    data[i] = ((Number) field).doubleValue();
                }
            }
            return denseVector;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$SetPartitionBasicStat.class */
    public static class SetPartitionBasicStat implements MapPartitionFunction<Row, SummaryResultTable> {
        private static final long serialVersionUID = -5607403479996476267L;
        private String[] colNames;
        private Class[] colTypes;
        private HasStatLevel_L1.StatLevel statLevel;
        private String[] selectedColNames;

        public SetPartitionBasicStat(TableSchema tableSchema) {
            this(tableSchema, HasStatLevel_L1.StatLevel.L1);
        }

        public SetPartitionBasicStat(TableSchema tableSchema, HasStatLevel_L1.StatLevel statLevel) {
            this.selectedColNames = null;
            this.colNames = tableSchema.getFieldNames();
            int length = this.colNames.length;
            this.colTypes = new Class[length];
            for (int i = 0; i < length; i++) {
                this.colTypes[i] = tableSchema.getFieldTypes()[i].getTypeClass();
            }
            this.statLevel = statLevel;
            this.selectedColNames = this.colNames;
        }

        public SetPartitionBasicStat(TableSchema tableSchema, String[] strArr, HasStatLevel_L1.StatLevel statLevel) {
            this.selectedColNames = null;
            this.colNames = tableSchema.getFieldNames();
            int length = this.colNames.length;
            this.colTypes = new Class[length];
            for (int i = 0; i < length; i++) {
                this.colTypes[i] = tableSchema.getFieldTypes()[i].getTypeClass();
            }
            this.statLevel = statLevel;
            this.selectedColNames = strArr;
        }

        public void mapPartition(Iterable<Row> iterable, Collector<SummaryResultTable> collector) throws Exception {
            SummaryResultTable batchSummary = SrtUtil.batchSummary(new WindowTable(this.colNames, this.colTypes, iterable), this.selectedColNames, 10, 10, SortUtils.SPLIT_POINT_SIZE, 10, this.statLevel);
            if (batchSummary != null) {
                collector.collect(batchSummary);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$TableSummarizerPartition.class */
    public static class TableSummarizerPartition implements MapPartitionFunction<Row, TableSummarizer> {
        private static final long serialVersionUID = -1625614901816383530L;
        private final boolean outerProduct;
        private final String[] colNames;
        private final TypeInformation<?>[] colTypes;

        TableSummarizerPartition(TableSchema tableSchema, boolean z) {
            this.outerProduct = z;
            this.colNames = tableSchema.getFieldNames();
            this.colTypes = tableSchema.getFieldTypes();
        }

        public void mapPartition(Iterable<Row> iterable, Collector<TableSummarizer> collector) throws Exception {
            TableSummarizer tableSummarizer = new TableSummarizer(new TableSchema(this.colNames, this.colTypes), this.outerProduct);
            Iterator<Row> it = iterable.iterator();
            while (it.hasNext()) {
                tableSummarizer = (TableSummarizer) tableSummarizer.visit(it.next());
            }
            collector.collect(tableSummarizer);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$VectorCoToVectorWithoutReservedColsMap.class */
    public static class VectorCoToVectorWithoutReservedColsMap implements MapFunction<Row, Vector> {
        private static final long serialVersionUID = -6220416346174572528L;
        private final int vectorColIndex;

        VectorCoToVectorWithoutReservedColsMap(int i) {
            this.vectorColIndex = i;
        }

        public Vector map(Row row) throws Exception {
            return VectorUtil.getVector(row.getField(this.vectorColIndex));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$VectorColToTableMap.class */
    private static class VectorColToTableMap implements MapFunction<Row, Row> {
        private static final long serialVersionUID = 4663377654190680333L;
        private final int vectorColIndex;
        private final int[] reservedColIndices;

        VectorColToTableMap(int i, int[] iArr) {
            this.vectorColIndex = i;
            this.reservedColIndices = null == iArr ? new int[0] : iArr;
        }

        public Row map(Row row) throws Exception {
            Row row2;
            Vector vector = VectorUtil.getVector(row.getField(this.vectorColIndex));
            DenseVector denseVector = vector instanceof DenseVector ? (DenseVector) vector : ((SparseVector) vector).toDenseVector();
            if (denseVector.getData() != null) {
                row2 = new Row(denseVector.size() + this.reservedColIndices.length);
                for (int i = 0; i < denseVector.size(); i++) {
                    row2.setField(i, Double.valueOf(denseVector.get(i)));
                }
                for (int i2 = 0; i2 < this.reservedColIndices.length; i2++) {
                    row2.setField(i2 + denseVector.size(), row.getField(this.reservedColIndices[i2]));
                }
            } else {
                row2 = new Row(this.reservedColIndices.length);
                for (int i3 = 0; i3 < this.reservedColIndices.length; i3++) {
                    row2.setField(i3, row.getField(this.reservedColIndices[i3]));
                }
            }
            return row2;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$VectorColToVectorWithReservedColsMap.class */
    public static class VectorColToVectorWithReservedColsMap implements MapFunction<Row, Tuple2<Vector, Row>> {
        private static final long serialVersionUID = -3222351920500305742L;
        private final int vectorColIndex;
        private final int[] reservedColIndices;

        VectorColToVectorWithReservedColsMap(int i, int[] iArr) {
            this.vectorColIndex = i;
            this.reservedColIndices = iArr;
        }

        public Tuple2<Vector, Row> map(Row row) throws Exception {
            Vector vector = VectorUtil.getVector(row.getField(this.vectorColIndex));
            if (vector == null) {
                throw new AkIllegalDataException("input vector is null");
            }
            Row row2 = new Row(this.reservedColIndices.length);
            for (int i = 0; i < this.reservedColIndices.length; i++) {
                row2.setField(i, row.getField(this.reservedColIndices[i]));
            }
            return Tuple2.of(vector, row2);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper$VectorSummarizerPartition.class */
    public static class VectorSummarizerPartition implements MapPartitionFunction<Vector, BaseVectorSummarizer> {
        private static final long serialVersionUID = 1065284716432882945L;
        private final boolean outerProduct;

        public VectorSummarizerPartition(boolean z) {
            this.outerProduct = z;
        }

        public void mapPartition(Iterable<Vector> iterable, Collector<BaseVectorSummarizer> collector) throws Exception {
            DenseVectorSummarizer denseVectorSummarizer = new DenseVectorSummarizer(this.outerProduct);
            Iterator<Vector> it = iterable.iterator();
            while (it.hasNext()) {
                denseVectorSummarizer = denseVectorSummarizer.visit(it.next());
            }
            collector.collect(denseVectorSummarizer);
        }
    }

    /**
     * Transforms the input into (vector, reserved row) pairs and computes the summary of the
     * vectors.
     *
     * @param batchOperator input operator
     * @param strArr        selected column names (alternative to {@code str})
     * @param str           vector column name (alternative to {@code strArr})
     * @param strArr2       reserved column names
     * @return tuple of (transformed data set, summary data set of its vectors)
     */
    public static Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> summaryHelper(BatchOperator<?> batchOperator, String[] strArr, String str, String[] strArr2) {
        DataSet<Tuple2<Vector, Row>> vectorsWithReserved = transformToVector(batchOperator, strArr, str, strArr2);
        // Drop the reserved row; only the vector takes part in the summary.
        DataSet<Vector> vectorsOnly = vectorsWithReserved.map(new MapFunction<Tuple2<Vector, Row>, Vector>() {
            private static final long serialVersionUID = -1465299071490594701L;

            @Override
            public Vector map(Tuple2<Vector, Row> pair) {
                return pair.f0;
            }
        });
        DataSet<BaseVectorSummary> stats = summary(vectorsOnly);
        return Tuple2.of(vectorsWithReserved, stats);
    }

    /**
     * Transforms the input into vectors and computes their summary.
     *
     * @param batchOperator input operator
     * @param strArr        selected column names (alternative to {@code str})
     * @param str           vector column name (alternative to {@code strArr})
     * @return tuple of (vector data set, summary data set)
     */
    public static Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> summaryHelper(BatchOperator<?> batchOperator, String[] strArr, String str) {
        DataSet<Vector> vectors = transformToVector(batchOperator, strArr, str);
        DataSet<BaseVectorSummary> stats = summary(vectors);
        return Tuple2.of(vectors, stats);
    }

    /**
     * Computes the table summary of the selected columns.
     *
     * @param batchOperator input operator
     * @param strArr        selected column names; must be non-empty
     * @return data set holding the {@link TableSummary}
     */
    public static DataSet<TableSummary> summary(BatchOperator<?> batchOperator, String[] strArr) {
        // Outer product is not requested for a plain summary.
        DataSet<TableSummarizer> merged = summarizer(batchOperator, strArr, false);
        return merged.map(new MapFunction<TableSummarizer, TableSummary>() {
            private static final long serialVersionUID = 8876418210242735806L;

            @Override
            public TableSummary map(TableSummarizer merger) {
                return merger.toSummary();
            }
        }).name("toSummary");
    }

    /**
     * Computes the vector summary of a vector column.
     *
     * @param batchOperator input operator
     * @param str           vector column name
     * @return data set holding the {@link BaseVectorSummary}
     */
    public static DataSet<BaseVectorSummary> vectorSummary(BatchOperator<?> batchOperator, String str) {
        // Outer product is not requested for a plain summary.
        DataSet<BaseVectorSummarizer> merged = vectorSummarizer(batchOperator, str, false);
        return merged.map(new MapFunction<BaseVectorSummarizer, BaseVectorSummary>() {
            private static final long serialVersionUID = -6426572658193278213L;

            @Override
            public BaseVectorSummary map(BaseVectorSummarizer merger) {
                return merger.toSummary();
            }
        }).name("toSummary");
    }

    /**
     * Computes the vector summary of a data set of vectors.
     *
     * @param dataSet vectors to summarise
     * @return data set holding the {@link BaseVectorSummary}
     */
    public static DataSet<BaseVectorSummary> summary(DataSet<Vector> dataSet) {
        // Outer product is not requested for a plain summary.
        DataSet<BaseVectorSummarizer> merged = summarizer(dataSet, false);
        return merged.map(new MapFunction<BaseVectorSummarizer, BaseVectorSummary>() {
            private static final long serialVersionUID = -2082435777065038687L;

            @Override
            public BaseVectorSummary map(BaseVectorSummarizer merger) {
                return merger.toSummary();
            }
        }).name("toSummary");
    }

    /**
     * Computes the table summary together with the correlation result of the selected columns.
     *
     * @param batchOperator input operator
     * @param strArr        selected column names; must be non-empty
     * @return data set holding (summary, correlation)
     */
    public static DataSet<Tuple2<TableSummary, CorrelationResult>> pearsonCorrelation(BatchOperator<?> batchOperator, String[] strArr) {
        // The summarizer is built with the outer-product flag switched on for correlation.
        DataSet<TableSummarizer> merged = summarizer(batchOperator, strArr, true);
        return merged.map(new MapFunction<TableSummarizer, Tuple2<TableSummary, CorrelationResult>>() {
            private static final long serialVersionUID = -5757257375097509823L;

            @Override
            public Tuple2<TableSummary, CorrelationResult> map(TableSummarizer merger) {
                TableSummary stats = merger.toSummary();
                CorrelationResult corr = merger.correlation();
                return Tuple2.of(stats, corr);
            }
        });
    }

    /**
     * Computes the vector summary together with the correlation result of a vector column.
     *
     * @param batchOperator input operator
     * @param str           vector column name
     * @return data set holding (summary, correlation)
     */
    public static DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> vectorPearsonCorrelation(BatchOperator<?> batchOperator, String str) {
        // The summarizer is built with the outer-product flag switched on for correlation.
        DataSet<BaseVectorSummarizer> merged = vectorSummarizer(batchOperator, str, true);
        return merged.map(new MapFunction<BaseVectorSummarizer, Tuple2<BaseVectorSummary, CorrelationResult>>() {
            private static final long serialVersionUID = -1745468840082156193L;

            @Override
            public Tuple2<BaseVectorSummary, CorrelationResult> map(BaseVectorSummarizer merger) {
                BaseVectorSummary stats = merger.toSummary();
                CorrelationResult corr = merger.correlation();
                return Tuple2.of(stats, corr);
            }
        });
    }

    /**
     * Converts the input rows to vectors, either from the selected columns or from a vector
     * column.
     *
     * @param batchOperator input operator
     * @param strArr        selected column names; used when non-empty
     * @param str           vector column name; used when {@code strArr} is null or empty
     * @return data set of vectors
     */
    public static DataSet<Vector> transformToVector(BatchOperator<?> batchOperator, String[] strArr, String str) {
        checkSimpleStatParameter(batchOperator, strArr, str, null);
        boolean hasSelectedCols = strArr != null && strArr.length != 0;
        if (hasSelectedCols) {
            DataSet<Row> rows = batchOperator.getDataSet();
            int[] selectedIdx = TableUtil.findColIndicesWithAssertAndHint(batchOperator.getColNames(), strArr);
            return rows.map(new ColsToVectorWithoutReservedColsMap(selectedIdx));
        }
        DataSet<Row> rows = batchOperator.getDataSet();
        int vectorIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str);
        return rows.map(new VectorCoToVectorWithoutReservedColsMap(vectorIdx));
    }

    /**
     * Converts the input rows to (vector, reserved row) pairs, building the vector either from
     * the selected columns or from a vector column.
     *
     * @param batchOperator input operator
     * @param strArr        selected column names; used when non-empty
     * @param str           vector column name; used when {@code strArr} is null or empty
     * @param strArr2       reserved column names
     * @return data set of (vector, reserved row) tuples
     */
    public static DataSet<Tuple2<Vector, Row>> transformToVector(BatchOperator<?> batchOperator, String[] strArr, String str, String[] strArr2) {
        checkSimpleStatParameter(batchOperator, strArr, str, strArr2);
        int[] reservedIdx = TableUtil.findColIndicesWithAssertAndHint(batchOperator.getColNames(), strArr2);
        boolean hasSelectedCols = strArr != null && strArr.length != 0;
        if (hasSelectedCols) {
            DataSet<Row> rows = batchOperator.getDataSet();
            int[] selectedIdx = TableUtil.findColIndicesWithAssertAndHint(batchOperator.getColNames(), strArr);
            return rows.map(new ColsToVectorWithReservedColsMap(selectedIdx, reservedIdx)).name("transform_data");
        }
        DataSet<Row> rows = batchOperator.getDataSet();
        int vectorIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str);
        return rows.map(new VectorColToVectorWithReservedColsMap(vectorIdx, reservedIdx)).name("transform_data");
    }

    /**
     * Converts the input rows to rows of double-valued columns followed by the reserved columns.
     * The double columns come either from the selected columns or from expanding a vector column.
     *
     * @param batchOperator input operator
     * @param strArr        selected column names; used when non-empty
     * @param str           vector column name; required when {@code strArr} is null or empty
     * @param strArr2       reserved column names; {@code null} means none
     * @return data set of converted rows
     * @throws AkIllegalOperatorParameterException if neither selected columns nor a vector column
     *                                             is given
     */
    public static DataSet<Row> transformToColumns(BatchOperator<?> batchOperator, String[] strArr, String str, String[] strArr2) {
        checkSimpleStatParameter(batchOperator, strArr, str, strArr2);
        int[] reservedIdx = strArr2 == null
            ? null
            : TableUtil.findColIndicesWithAssertAndHint(batchOperator.getColNames(), strArr2);
        boolean noSelectedCols = strArr == null || strArr.length == 0;
        if (noSelectedCols) {
            if (str == null) {
                throw new AkIllegalOperatorParameterException("selectedColName and vectorColName must be set one only.");
            }
            DataSet<Row> rows = batchOperator.getDataSet();
            int vectorIdx = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str);
            return rows.map(new VectorColToTableMap(vectorIdx, reservedIdx));
        }
        DataSet<Row> rows = batchOperator.getDataSet();
        int[] selectedIdx = TableUtil.findColIndicesWithAssertAndHint(batchOperator.getSchema(), strArr);
        return rows.map(new ColsToDoubleColsMap(selectedIdx, reservedIdx));
    }

    /**
     * Validates the column parameters shared by the transform helpers: selected columns and a
     * vector column must not both be given, and the {@link TableUtil} assertions are applied to
     * the selected, vector and reserved columns in that order.
     *
     * @throws AkIllegalOperatorParameterException if both selected columns and a vector column
     *                                             are set
     */
    private static void checkSimpleStatParameter(BatchOperator<?> in, String[] selectedColNames, String vectorColName, String[] reservedColNames) {
        boolean hasSelectedCols = selectedColNames != null && selectedColNames.length != 0;
        boolean hasVectorCol = vectorColName != null;
        if (hasSelectedCols && hasVectorCol) {
            throw new AkIllegalOperatorParameterException("selectedColName and vectorColName must be set one only.");
        }

        // Selected columns: existence, then numeric type.
        TableUtil.assertSelectedColExist(in.getColNames(), selectedColNames);
        TableUtil.assertNumericalCols(in.getSchema(), selectedColNames);

        // Vector column: existence, then vector type.
        TableUtil.assertSelectedColExist(in.getColNames(), vectorColName);
        TableUtil.assertVectorCols(in.getSchema(), vectorColName);

        // Reserved columns: existence only.
        TableUtil.assertSelectedColExist(in.getColNames(), reservedColNames);
    }

    /**
     * Projects the input onto the selected columns and builds the merged table summarizer.
     *
     * @throws AkIllegalOperatorParameterException if no column is selected
     */
    private static DataSet<TableSummarizer> summarizer(BatchOperator<?> in, String[] selectedColNames, boolean calculateOuterProduct) {
        boolean nothingSelected = selectedColNames == null || selectedColNames.length == 0;
        if (nothingSelected) {
            throw new AkIllegalOperatorParameterException("selectedColNames must be set.");
        }
        BatchOperator<?> projected = Preprocessing.select(in, selectedColNames);
        DataSet<Row> rows = projected.getDataSet();
        TableSchema schema = projected.getSchema();
        return summarizer(rows, schema, calculateOuterProduct);
    }

    /**
     * Asserts that the vector column exists, converts it to vectors and builds the merged vector
     * summarizer.
     */
    private static DataSet<BaseVectorSummarizer> vectorSummarizer(BatchOperator<?> in, String vectorColName, boolean calculateOuterProduct) {
        TableUtil.assertSelectedColExist(in.getColNames(), vectorColName);
        DataSet<Vector> vectors = transformToVector(in, null, vectorColName);
        return summarizer(vectors, calculateOuterProduct);
    }

    /**
     * Builds one summarizer per partition and merges them into a single summarizer.
     *
     * @param dataSet vectors to summarise
     * @param z       outer-product flag passed to {@link VectorSummarizerPartition}
     * @return data set holding the merged {@link BaseVectorSummarizer}
     */
    public static DataSet<BaseVectorSummarizer> summarizer(DataSet<Vector> dataSet, boolean z) {
        DataSet<BaseVectorSummarizer> partials = dataSet
            .mapPartition(new VectorSummarizerPartition(z))
            .name("summarizer_map");
        return partials.reduce(new ReduceFunction<BaseVectorSummarizer>() {
            private static final long serialVersionUID = 5993118429985684366L;

            @Override
            public BaseVectorSummarizer reduce(BaseVectorSummarizer left, BaseVectorSummarizer right) {
                return VectorSummarizerUtil.merge(left, right);
            }
        }).name("summarizer_reduce");
    }

    /**
     * Builds one table summarizer per partition and merges them into a single summarizer.
     */
    private static DataSet<TableSummarizer> summarizer(DataSet<Row> rows, TableSchema schema, boolean calculateOuterProduct) {
        DataSet<TableSummarizer> partials = rows.mapPartition(new TableSummarizerPartition(schema, calculateOuterProduct));
        return partials.reduce(new ReduceFunction<TableSummarizer>() {
            private static final long serialVersionUID = 964700189305139868L;

            @Override
            public TableSummarizer reduce(TableSummarizer left, TableSummarizer right) {
                return TableSummarizer.merge(left, right);
            }
        });
    }

    /**
     * Computes a {@link SummaryResultTable} per partition over all columns of the input and
     * combines the partial tables into one.
     *
     * @param batchOperator input operator
     * @param statLevel     stat level passed to {@link SetPartitionBasicStat}
     * @return data set holding the combined {@link SummaryResultTable}
     */
    public static DataSet<SummaryResultTable> getSRT(BatchOperator<?> batchOperator, HasStatLevel_L1.StatLevel statLevel) {
        DataSet<SummaryResultTable> partials = batchOperator.getDataSet()
            .mapPartition(new SetPartitionBasicStat(batchOperator.getSchema(), statLevel));
        return partials.reduce(new ReduceFunction<SummaryResultTable>() {
            private static final long serialVersionUID = 6050967884386340459L;

            @Override
            public SummaryResultTable reduce(SummaryResultTable left, SummaryResultTable right) {
                // A null side contributes nothing; hand back the other side as-is.
                if (left == null) {
                    return right;
                }
                if (right == null) {
                    return left;
                }
                return SummaryResultTable.combine(left, right);
            }
        });
    }
}
