package com.alibaba.alink.operator.common.distance;

import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.Arrays;

/**
 * Pearson correlation distance, {@code 1 - r}, where {@code r} is the cosine similarity of the two
 * inputs after each has been centered by subtracting its own mean.
 *
 * <p>If either centered input has zero L2 norm (for example a constant vector), the similarity is
 * treated as 0 and the distance is therefore 1.
 *
 * <p>NOTE(review): for {@link SparseVector} inputs and {@link FastDistanceSparseData}, only the
 * stored entries are averaged and centered; implicit zeros are ignored. The result is therefore not
 * the textbook Pearson correlation over the full dimension and may differ from the result for the
 * dense form of the same vector — confirm this is intended.
 */
public class PearsonDistance extends CosineDistance {
    private static final long serialVersionUID = 6706414118581906920L;

    /**
     * Computes the Pearson distance between two dense arrays.
     *
     * <p>The arguments are not modified; centering is performed on private copies.
     *
     * @param array1 first array
     * @param array2 second array; {@link BLAS#dot} is used on the pair, so lengths are expected to
     *     match
     * @return {@code 1 - r}, or {@code 1.0} when either centered array has zero norm
     */
    @Override
    public double calc(double[] array1, double[] array2) {
        // Copy first: centering in place would silently corrupt the caller's data.
        double[] x = array1.clone();
        double[] y = array2.clone();
        minusAvg(x);
        minusAvg(y);
        double dot = BLAS.dot(x, y);
        double norm = Math.sqrt(BLAS.dot(x, x) * BLAS.dot(y, y));
        return 1.0 - (norm > 0.0 ? dot / norm : 0.0);
    }

    /**
     * Subtracts the arithmetic mean of {@code values} from every element, in place.
     *
     * @param values the array to center
     */
    private static void minusAvg(double[] values) {
        // DoubleStream.sum() is kept (rather than a plain loop) so the computed mean is unchanged.
        double mean = Arrays.stream(values).sum() / values.length;
        for (int i = 0; i < values.length; i++) {
            values[i] -= mean;
        }
    }

    /**
     * Centers a vector in place: all entries of a {@link DenseVector}, or only the stored values of
     * any other vector (which is cast to {@link SparseVector}).
     *
     * @param vector the vector to center; mutated
     */
    private static void minusAvg(Vector vector) {
        if (vector instanceof DenseVector) {
            minusAvg(((DenseVector) vector).getData());
        } else {
            minusAvg(((SparseVector) vector).getValues());
        }
    }

    /**
     * Computes the Pearson distance between two vectors.
     *
     * <p>The arguments are not modified; centering is performed on clones.
     *
     * @param vector first vector
     * @param vector2 second vector
     * @return {@code 1 - r}, or {@code 1.0} when either centered vector has zero norm
     */
    @Override
    public double calc(Vector vector, Vector vector2) {
        // Clone first: centering in place would silently corrupt the caller's vectors.
        Vector x = vector.clone();
        Vector y = vector2.clone();
        minusAvg(x);
        minusAvg(y);
        double dot = MatVecOp.dot(x, y);
        double norm = x.normL2() * y.normL2();
        return 1.0 - (norm > 0.0 ? dot / norm : 0.0);
    }

    /**
     * Prepares label data for fast Pearson distance, in place: each stored vector is centered on its
     * mean and then scaled to unit L2 norm. A vector whose centered norm is zero is still centered
     * (so its entries become zero) but not scaled; all three data layouts behave the same way.
     *
     * @param fastDistanceData vector, matrix or sparse label data; any other type is cast to
     *     {@link FastDistanceSparseData} and fails with {@link ClassCastException}
     */
    @Override
    public void updateLabel(FastDistanceData fastDistanceData) {
        if (fastDistanceData instanceof FastDistanceVectorData) {
            updateVectorLabel((FastDistanceVectorData) fastDistanceData);
        } else if (fastDistanceData instanceof FastDistanceMatrixData) {
            updateMatrixLabel((FastDistanceMatrixData) fastDistanceData);
        } else {
            updateSparseLabel((FastDistanceSparseData) fastDistanceData);
        }
    }

    /** Centers the single stored vector and scales it to unit norm when that norm is positive. */
    private static void updateVectorLabel(FastDistanceVectorData vectorData) {
        minusAvg(vectorData.vector);
        double norm = Math.sqrt(MatVecOp.dot(vectorData.vector, vectorData.vector));
        if (norm > 0.0) {
            vectorData.vector.scaleEqual(1.0 / norm);
        }
    }

    /**
     * Centers and normalizes every consecutive block of {@code numRows} values in the matrix data,
     * each block being one vector (presumably a column of a column-major matrix — confirm).
     */
    private static void updateMatrixLabel(FastDistanceMatrixData matrixData) {
        int numRows = matrixData.vectors.numRows();
        double[] data = matrixData.vectors.getData();
        if (numRows <= 0) {
            // Nothing to normalize; also prevents a non-advancing loop.
            return;
        }
        for (int start = 0; start < data.length; start += numRows) {
            centerAndNormalize(data, start, start + numRows);
        }
    }

    /**
     * Centers {@code data[from, to)} on its mean, then scales it to unit L2 norm if that norm is
     * positive. The norm is computed from the already-centered values (two-pass), which avoids the
     * negative/NaN results that {@code sumSq - n * mean^2} can produce through cancellation.
     */
    private static void centerAndNormalize(double[] data, int from, int to) {
        double sum = 0.0;
        for (int k = from; k < to; k++) {
            sum += data[k];
        }
        double mean = sum / (to - from);

        double sumSq = 0.0;
        for (int k = from; k < to; k++) {
            data[k] -= mean;
            sumSq += data[k] * data[k];
        }

        double norm = Math.sqrt(sumSq);
        if (norm > 0.0) {
            for (int k = from; k < to; k++) {
                data[k] /= norm;
            }
        }
    }

    /**
     * Centers and normalizes the sparse label data in place. Entry {@code values[i][j]} belongs to
     * the vector with id {@code indices[i][j]} (in {@code [0, vectorNum)}); rows of {@code indices}
     * may be {@code null} and are skipped. Only stored entries contribute to each vector's mean and
     * norm.
     */
    private static void updateSparseLabel(FastDistanceSparseData sparseData) {
        int vectorNum = sparseData.vectorNum;
        int[][] indices = sparseData.getIndices();
        double[][] values = sparseData.getValues();

        // Pass 1: per-vector mean over its stored entries.
        double[] mean = new double[vectorNum];
        int[] count = new int[vectorNum];
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] == null) {
                continue;
            }
            for (int j = 0; j < indices[i].length; j++) {
                int id = indices[i][j];
                mean[id] += values[i][j];
                count[id]++;
            }
        }
        for (int id = 0; id < vectorNum; id++) {
            if (count[id] > 0) {
                mean[id] /= count[id];
            }
        }

        // Pass 2: center every stored entry and accumulate the centered sum of squares.
        double[] norm = new double[vectorNum];
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] == null) {
                continue;
            }
            for (int j = 0; j < indices[i].length; j++) {
                int id = indices[i][j];
                double centered = values[i][j] - mean[id];
                values[i][j] = centered;
                norm[id] += centered * centered;
            }
        }
        for (int id = 0; id < vectorNum; id++) {
            norm[id] = Math.sqrt(norm[id]);
        }

        // Pass 3: scale to unit norm where the centered norm is positive.
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] == null) {
                continue;
            }
            for (int j = 0; j < indices[i].length; j++) {
                int id = indices[i][j];
                if (norm[id] > 0.0) {
                    values[i][j] /= norm[id];
                }
            }
        }
    }
}
