package com.alibaba.alink.operator.common.statistics.basicstatistic;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.tree.Criteria;

/**
 * Base summarizer whose samples are {@link Vector}s (see {@link #visit(Vector)}).
 *
 * <p>Derives covariance and correlation matrices from the {@code outerProduct} and {@code count}
 * state inherited from {@link BaseSummarizer} together with the sums and standard deviations
 * reported by {@link #toSummary()}. Both methods return {@code null} when {@code outerProduct}
 * is {@code null}.
 */
public abstract class BaseVectorSummarizer extends BaseSummarizer {
    private static final long serialVersionUID = -6594023541408617732L;

    /**
     * Feeds one vector into this summarizer.
     *
     * <p>NOTE(review): the accumulation semantics, and whether {@code this} is returned, are
     * defined by subclasses — confirm against the implementations.
     *
     * @param vector the sample to visit
     * @return the summarizer after visiting {@code vector}
     */
    public abstract BaseVectorSummarizer visit(Vector vector);

    /**
     * Builds a summary of the data visited so far. {@link #covariance()} reads its
     * {@code sum()} and {@link #correlation()} reads its {@code standardDeviation()}.
     *
     * @return the current summary
     */
    public abstract BaseVectorSummary toSummary();

    /**
     * Computes the covariance matrix from the accumulated outer product and column sums.
     *
     * <pre>
     * cov(i, j) = (outerProduct(i, j) - sum(i) * sum(j) / count) / (count - 1)
     * </pre>
     *
     * <p>This is the unbiased sample covariance, assuming {@code outerProduct(i, j)} holds the
     * sum of x_i * x_j over all visited samples (TODO confirm in {@code BaseSummarizer}).
     *
     * <p>The result is {@code size x size}, where {@code size} is the length of
     * {@code toSummary().sum()}. Entries whose row or column index is at or beyond
     * {@code outerProduct.numRows()} are left at {@code 0.0}. When {@code count - 1} is zero, the
     * division yields NaN or infinite entries.
     *
     * @return the symmetric covariance matrix, or {@code null} if no outer product was accumulated
     */
    @Override
    public DenseMatrix covariance() {
        if (this.outerProduct == null) {
            return null;
        }
        Vector sum = toSummary().sum();
        int size = sum.size();
        // Only the overlap of the summary dimension and the outer-product dimension is filled.
        int filled = Math.min(size, this.outerProduct.numRows());
        double n = this.count;
        double dof = this.count - 1;
        double[][] cov = new double[size][size];
        for (int i = 0; i < filled; i++) {
            double sumI = sum.get(i);
            for (int j = i; j < filled; j++) {
                double value = (this.outerProduct.get(i, j) - sumI * sum.get(j) / n) / dof;
                cov[i][j] = value;
                cov[j][i] = value;
            }
        }
        return new DenseMatrix(cov);
    }

    /**
     * Computes the correlation matrix by dividing each entry of {@link #covariance()} by the
     * standard deviations of its row and column, taken from {@code toSummary().standardDeviation()}.
     *
     * <p>Off-diagonal entries are set to {@code Criteria.INVALID_GAIN} when the covariance is NaN,
     * when it equals {@code Criteria.INVALID_GAIN}, or when the scaled value is NaN or infinite
     * (for example, a zero standard deviation for a constant column). All other off-diagonal
     * entries are clamped to {@code [-1, 1]}, and the diagonal is always {@code 1.0}.
     *
     * @return the correlation result, or {@code null} if no outer product was accumulated
     */
    @Override
    public CorrelationResult correlation() {
        if (this.outerProduct == null) {
            return null;
        }
        DenseMatrix corr = covariance();
        Vector stdDev = toSummary().standardDeviation();
        int n = corr.numRows();
        for (int i = 0; i < n; i++) {
            // The diagonal is defined as 1 regardless of the computed variance.
            corr.set(i, i, 1.0d);
            for (int j = i + 1; j < n; j++) {
                double value = toCorrelation(corr.get(i, j), stdDev.get(i), stdDev.get(j));
                corr.set(i, j, value);
                corr.set(j, i, value);
            }
        }
        return new CorrelationResult(corr);
    }

    /**
     * Converts one off-diagonal covariance entry into a correlation coefficient.
     *
     * <p>NOTE(review): {@code Criteria.INVALID_GAIN} belongs to the decision-tree code. Its use
     * here looks like a decompiler substitution for a numeric literal (presumably {@code 0.0}).
     * Confirm the value and consider a dedicated constant.
     *
     * @param covariance the covariance of columns i and j
     * @param stdI standard deviation of column i
     * @param stdJ standard deviation of column j
     * @return the coefficient clamped to [-1, 1], or {@code Criteria.INVALID_GAIN} if undefined
     */
    private static double toCorrelation(double covariance, double stdI, double stdJ) {
        if (Double.isNaN(covariance) || covariance == Criteria.INVALID_GAIN) {
            return Criteria.INVALID_GAIN;
        }
        double r = covariance / stdI / stdJ;
        if (Double.isNaN(r) || Double.isInfinite(r)) {
            // A zero or undefined standard deviation makes the coefficient undefined. Without
            // this check an infinite quotient would be clamped to +/-1, reporting a spurious
            // perfect correlation for a constant column.
            return Criteria.INVALID_GAIN;
        }
        // Rounding error can push |r| slightly outside [-1, 1].
        return Math.max(-1.0d, Math.min(1.0d, r));
    }
}
