package com.alibaba.alink.operator.common.statistics.basicstatistic;

import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.statistics.statistics.BaseMeasureIterator;
import com.alibaba.alink.operator.common.statistics.statistics.BooleanMeasureIterator;
import com.alibaba.alink.operator.common.statistics.statistics.DateMeasureIterator;
import com.alibaba.alink.operator.common.statistics.statistics.NumberMeasureIterator;
import com.alibaba.alink.operator.common.statistics.statistics.StatisticsIteratorFactory;
import com.alibaba.alink.operator.common.tree.Criteria;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;

/* loaded from: input_file:com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummarizer.class */
public class TableSummarizer extends BaseSummarizer {
    private static final long serialVersionUID = 4588962274305185787L;
    public String[] colNames;
    TypeInformation<?>[] colTypes;
    BaseMeasureIterator[] statIterators;
    private int n;
    private int numberN;
    private int[] numericalColIndices;
    private Double[] vals;
    DenseMatrix xSum;
    DenseMatrix xSquareSum;
    DenseMatrix xyCount;

    private TableSummarizer() {
    }

    public TableSummarizer(TableSchema tableSchema, boolean z) {
        this.colNames = tableSchema.getFieldNames();
        this.colTypes = tableSchema.getFieldTypes();
        this.calculateOuterProduct = z;
        this.n = this.colNames.length;
        this.numericalColIndices = calcCovColIndices(new TableSchema(this.colNames, this.colTypes));
        this.numberN = this.numericalColIndices.length;
    }

    public BaseSummarizer visit(Row row) {
        if (this.n != row.getArity()) {
            throw new AkIllegalStateException("row size is not equal with table col num.");
        }
        if (this.count == 0) {
            init();
        }
        this.count++;
        for (int i = 0; i < this.n; i++) {
            this.statIterators[i].visit(row.getField(i));
        }
        if (this.calculateOuterProduct) {
            for (int i2 = 0; i2 < this.numberN; i2++) {
                Object field = row.getField(this.numericalColIndices[i2]);
                if (field == null) {
                    this.vals[i2] = null;
                } else if (field instanceof Boolean) {
                    this.vals[i2] = Double.valueOf(((Boolean) field).booleanValue() ? 1.0d : Criteria.INVALID_GAIN);
                } else {
                    this.vals[i2] = Double.valueOf(((Number) field).doubleValue());
                }
            }
            for (int i3 = 0; i3 < this.numberN; i3++) {
                if (this.vals[i3] != null) {
                    double doubleValue = this.vals[i3].doubleValue();
                    for (int i4 = i3; i4 < this.numberN; i4++) {
                        if (this.vals[i4] != null) {
                            this.outerProduct.add(i3, i4, doubleValue * this.vals[i4].doubleValue());
                            this.xSum.add(i3, i4, doubleValue);
                            this.xSquareSum.add(i3, i4, doubleValue * doubleValue);
                            this.xyCount.add(i3, i4, 1.0d);
                            if (i4 != i3) {
                                this.xSum.add(i4, i3, this.vals[i4].doubleValue());
                                this.xSquareSum.add(i4, i3, this.vals[i4].doubleValue() * this.vals[i4].doubleValue());
                                this.xyCount.add(i4, i3, 1.0d);
                            }
                        }
                    }
                }
            }
        }
        return this;
    }

    public String toString() {
        StringBuilder append = new StringBuilder().append("count: ").append(this.count).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
        if (this.count != 0) {
            for (int i = 0; i < this.n; i++) {
                append.append(this.colNames[i]).append(": ").append(this.statIterators[i]);
            }
        }
        return append.toString();
    }

    public TableSummary toSummary() {
        TableSummary tableSummary = new TableSummary();
        tableSummary.numericalColIndices = this.numericalColIndices;
        tableSummary.colNames = this.colNames;
        tableSummary.count = this.count;
        tableSummary.sum = new DenseVector(this.numberN);
        tableSummary.sum2 = new DenseVector(this.numberN);
        tableSummary.sum3 = new DenseVector(this.numberN);
        tableSummary.sum4 = new DenseVector(this.numberN);
        tableSummary.normL1 = new DenseVector(this.numberN);
        tableSummary.minDouble = new DenseVector(this.numberN);
        tableSummary.maxDouble = new DenseVector(this.numberN);
        tableSummary.numMissingValue = new long[this.n];
        tableSummary.min = new Object[this.numberN];
        tableSummary.max = new Object[this.numberN];
        if (this.count > 0) {
            for (int i = 0; i < this.n; i++) {
                tableSummary.numMissingValue[i] = this.statIterators[i].missingCount();
            }
            for (int i2 = 0; i2 < this.numberN; i2++) {
                BaseMeasureIterator baseMeasureIterator = this.statIterators[this.numericalColIndices[i2]];
                if (baseMeasureIterator instanceof NumberMeasureIterator) {
                    NumberMeasureIterator numberMeasureIterator = (NumberMeasureIterator) baseMeasureIterator;
                    tableSummary.sum.set(i2, numberMeasureIterator.sum);
                    tableSummary.sum2.set(i2, numberMeasureIterator.sum2);
                    tableSummary.sum3.set(i2, numberMeasureIterator.sum3);
                    tableSummary.sum4.set(i2, numberMeasureIterator.sum4);
                    tableSummary.minDouble.set(i2, numberMeasureIterator.min.doubleValue());
                    tableSummary.maxDouble.set(i2, numberMeasureIterator.max.doubleValue());
                    tableSummary.normL1.set(i2, numberMeasureIterator.normL1);
                    tableSummary.min[i2] = numberMeasureIterator.min;
                    tableSummary.max[i2] = numberMeasureIterator.max;
                } else if (baseMeasureIterator instanceof BooleanMeasureIterator) {
                    BooleanMeasureIterator booleanMeasureIterator = (BooleanMeasureIterator) baseMeasureIterator;
                    tableSummary.sum.set(i2, booleanMeasureIterator.countTrue);
                    tableSummary.sum2.set(i2, booleanMeasureIterator.countTrue);
                    tableSummary.sum3.set(i2, booleanMeasureIterator.countTrue);
                    tableSummary.sum4.set(i2, booleanMeasureIterator.countTrue);
                    tableSummary.normL1.set(i2, booleanMeasureIterator.countTrue);
                    tableSummary.minDouble.set(i2, booleanMeasureIterator.countFalse > 0 ? Criteria.INVALID_GAIN : 1.0d);
                    tableSummary.maxDouble.set(i2, booleanMeasureIterator.countTrue > 0 ? 1.0d : Criteria.INVALID_GAIN);
                    tableSummary.min[i2] = Boolean.valueOf(booleanMeasureIterator.countFalse <= 0);
                    tableSummary.max[i2] = Boolean.valueOf(booleanMeasureIterator.countTrue > 0);
                } else if (baseMeasureIterator instanceof DateMeasureIterator) {
                    DateMeasureIterator dateMeasureIterator = (DateMeasureIterator) baseMeasureIterator;
                    tableSummary.sum.set(i2, Double.NaN);
                    tableSummary.sum2.set(i2, Double.NaN);
                    tableSummary.sum3.set(i2, Double.NaN);
                    tableSummary.sum4.set(i2, Double.NaN);
                    tableSummary.minDouble.set(i2, dateMeasureIterator.min.getTime());
                    tableSummary.maxDouble.set(i2, dateMeasureIterator.max.getTime());
                    tableSummary.min[i2] = dateMeasureIterator.min;
                    tableSummary.max[i2] = dateMeasureIterator.max;
                }
            }
        }
        return tableSummary;
    }

    @Override // com.alibaba.alink.operator.common.statistics.basicstatistic.BaseSummarizer
    public CorrelationResult correlation() {
        if (this.outerProduct == null) {
            return null;
        }
        DenseMatrix covariance = covariance();
        int numRows = covariance.numRows();
        for (int i = 0; i < this.numericalColIndices.length; i++) {
            int i2 = this.numericalColIndices[i];
            for (int i3 = 0; i3 < this.numericalColIndices.length; i3++) {
                int i4 = this.numericalColIndices[i3];
                double d = covariance.get(i2, i4);
                if (!Double.isNaN(d) && d != Criteria.INVALID_GAIN) {
                    covariance.set(i2, i4, d / Math.sqrt(Math.max(Criteria.INVALID_GAIN, (this.xSquareSum.get(i, i3) - ((this.xSum.get(i, i3) * this.xSum.get(i, i3)) / this.xyCount.get(i, i3))) / (this.xyCount.get(i, i3) - 1.0d)) * Math.max(Criteria.INVALID_GAIN, (this.xSquareSum.get(i3, i) - ((this.xSum.get(i3, i) * this.xSum.get(i3, i)) / this.xyCount.get(i3, i))) / (this.xyCount.get(i3, i) - 1.0d))));
                }
            }
        }
        for (int i5 = 0; i5 < numRows; i5++) {
            for (int i6 = 0; i6 < numRows; i6++) {
                if (!Double.isNaN(covariance.get(i5, i6))) {
                    if (i5 == i6) {
                        covariance.set(i5, i5, 1.0d);
                    } else if (covariance.get(i5, i6) > 1.0d) {
                        covariance.set(i5, i6, 1.0d);
                    } else if (covariance.get(i5, i6) < -1.0d) {
                        covariance.set(i5, i6, -1.0d);
                    }
                }
            }
        }
        return new CorrelationResult(covariance, this.colNames);
    }

    @Override // com.alibaba.alink.operator.common.statistics.basicstatistic.BaseSummarizer
    public DenseMatrix covariance() {
        if (this.outerProduct == null) {
            return null;
        }
        double[][] dArr = new double[this.n][this.n];
        for (int i = 0; i < this.n; i++) {
            for (int i2 = 0; i2 < this.n; i2++) {
                dArr[i][i2] = Double.NaN;
            }
        }
        for (int i3 = 0; i3 < this.numericalColIndices.length; i3++) {
            int i4 = this.numericalColIndices[i3];
            for (int i5 = i3; i5 < this.numericalColIndices.length; i5++) {
                int i6 = this.numericalColIndices[i5];
                double d = this.xyCount.get(i3, i5);
                double d2 = (this.outerProduct.get(i3, i5) - ((this.xSum.get(i3, i5) * this.xSum.get(i5, i3)) / d)) / (d - 1.0d);
                dArr[i4][i6] = d2;
                dArr[i6][i4] = d2;
            }
        }
        return new DenseMatrix(dArr);
    }

    TableSummarizer copy() {
        TableSummarizer tableSummarizer = new TableSummarizer();
        tableSummarizer.colNames = (String[]) this.colNames.clone();
        tableSummarizer.count = this.count;
        if (this.count != 0) {
            tableSummarizer.statIterators = new BaseMeasureIterator[this.n];
            for (int i = 0; i < this.n; i++) {
                tableSummarizer.statIterators[i] = this.statIterators[i].m575clone();
            }
        }
        if (this.outerProduct != null) {
            tableSummarizer.numericalColIndices = (int[]) this.numericalColIndices.clone();
            tableSummarizer.outerProduct = this.outerProduct.m134clone();
            tableSummarizer.xSum = this.xSum.m134clone();
            tableSummarizer.xSquareSum = this.xSquareSum.m134clone();
            tableSummarizer.xyCount = this.xyCount.m134clone();
        }
        return tableSummarizer;
    }

    private void init() {
        this.statIterators = new BaseMeasureIterator[this.n];
        for (int i = 0; i < this.n; i++) {
            this.statIterators[i] = StatisticsIteratorFactory.getMeasureIterator(this.colTypes[i]);
        }
        if (this.calculateOuterProduct) {
            this.vals = new Double[this.numberN];
            this.outerProduct = new DenseMatrix(this.numberN, this.numberN);
            this.xSum = new DenseMatrix(this.numberN, this.numberN);
            this.xSquareSum = new DenseMatrix(this.numberN, this.numberN);
            this.xyCount = new DenseMatrix(this.numberN, this.numberN);
        }
    }

    private int[] calcCovColIndices(TableSchema tableSchema) {
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < tableSchema.getFieldNames().length; i++) {
            TypeInformation typeInformation = (TypeInformation) tableSchema.getFieldType(i).get();
            if (TableUtil.isSupportedNumericType(typeInformation) || TableUtil.isSupportedBoolType(typeInformation) || TableUtil.isSupportedDateType(typeInformation)) {
                arrayList.add(Integer.valueOf(i));
            }
        }
        return arrayList.stream().mapToInt((v0) -> {
            return Integer.valueOf(v0);
        }).toArray();
    }

    public static TableSummarizer merge(TableSummarizer tableSummarizer, TableSummarizer tableSummarizer2) {
        if (tableSummarizer2.count == 0) {
            return tableSummarizer;
        }
        if (tableSummarizer.count == 0) {
            return tableSummarizer2.copy();
        }
        tableSummarizer.count += tableSummarizer2.count;
        if (tableSummarizer.n != tableSummarizer2.n) {
            throw new AkIllegalStateException("left stat cols is not equal with right stat cols");
        }
        for (int i = 0; i < tableSummarizer.n; i++) {
            tableSummarizer.statIterators[i].merge(tableSummarizer2.statIterators[i]);
        }
        if (tableSummarizer.outerProduct != null && tableSummarizer2.outerProduct != null) {
            tableSummarizer.outerProduct.plusEquals(tableSummarizer2.outerProduct);
            tableSummarizer.xSum.plusEquals(tableSummarizer2.xSum);
            tableSummarizer.xSquareSum.plusEquals(tableSummarizer2.xSquareSum);
            tableSummarizer.xyCount.plusEquals(tableSummarizer2.xyCount);
        } else if (tableSummarizer.outerProduct == null && tableSummarizer2.outerProduct != null) {
            tableSummarizer.outerProduct = tableSummarizer2.outerProduct.m134clone();
            tableSummarizer.xSum = tableSummarizer2.xSum.m134clone();
            tableSummarizer.xSquareSum = tableSummarizer2.xSquareSum.m134clone();
            tableSummarizer.xyCount = tableSummarizer2.xyCount.m134clone();
        }
        return tableSummarizer;
    }
}
