package com.alibaba.alink.operator.common.clustering.kmeans;

import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceData;
import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/LocalKmeansFunc.class */
class LocalKmeansFunc implements Serializable {
    private static final long serialVersionUID = 480799516701310627L;
    private static final Logger LOG = LoggerFactory.getLogger(LocalKmeansFunc.class);

    /** Upper bound on the number of assignment/update passes performed by {@link #kmeans}. */
    private static final int LOCAL_MAX_ITER = 30;

    LocalKmeansFunc() {
    }

    /**
     * Runs a weighted k-means (assign / recompute loop) over an in-memory set of samples.
     *
     * <p>Initial centroids are chosen by {@link #sampleInitialCentroids}. Each pass assigns every sample to a
     * centroid via {@code KMeansUtil.updateSumMatrix}, then replaces each centroid with the weighted mean of
     * its assigned samples. A cluster that received no weight is reseeded with a randomly chosen sample.
     * The loop stops when no assignment changed during a pass, or after {@code LOCAL_MAX_ITER} passes.
     *
     * @param k             number of clusters
     * @param sampleWeights weight of each sample, index-aligned with {@code samples}
     * @param samples       the samples to cluster
     * @param distance      distance used for seeding and assignment
     * @param vectorSize    dimension of the sample vectors
     * @param seed          seed for the random generator used for seeding and empty-cluster reseeding
     * @return the final centroids, stored one per column (column {@code c} at offset {@code c * vectorSize})
     * @throws IllegalArgumentException if {@code k > 0} but {@code samples} is empty
     */
    public static FastDistanceMatrixData kmeans(int k, long[] sampleWeights, FastDistanceVectorData[] samples,
                                                FastDistance distance, int vectorSize, int seed) {
        if (k > 0 && samples.length == 0) {
            throw new IllegalArgumentException(
                "Cannot pick " + k + " initial centroids from an empty sample set.");
        }
        Random random = new Random(seed);
        FastDistanceMatrixData centroids = KMeansUtil.buildCentroidsMatrix(
            sampleInitialCentroids(k, sampleWeights, samples, distance, random), distance, vectorSize);

        // For cluster c, sumData[c * (vectorSize + 1) .. + vectorSize - 1] is read as the weighted feature
        // sum and sumData[c * (vectorSize + 1) + vectorSize] as the total weight assigned to c.
        DenseMatrix sumMatrix = new DenseMatrix(vectorSize + 1, k);
        double[] sumData = sumMatrix.getData();
        DenseMatrix distanceBuffer = new DenseMatrix(k, 1);
        double[] centroidData = centroids.getVectors().getData();

        // -1 is never a valid cluster id, so the first pass can never be mistaken for convergence.
        int[] assignments = new int[samples.length];
        Arrays.fill(assignments, -1);

        boolean converged = false;
        int iter = 0;
        while (!converged && iter < LOCAL_MAX_ITER) {
            iter++;
            converged = true;

            // updateSumMatrix adds each sample into sumData (it is called once per sample against the same
            // buffer), so the buffer must be cleared every pass; otherwise the centroids would average over
            // the assignments of all previous passes instead of the current one.
            Arrays.fill(sumData, 0.0);
            for (int i = 0; i < samples.length; i++) {
                int clusterId = KMeansUtil.updateSumMatrix(samples[i], sampleWeights[i], centroids, vectorSize,
                    sumData, k, distance, distanceBuffer);
                if (clusterId != assignments[i]) {
                    assignments[i] = clusterId;
                    converged = false;
                }
            }

            Arrays.fill(centroidData, 0.0);
            for (int c = 0; c < k; c++) {
                int centroidOffset = c * vectorSize;
                int sumOffset = centroidOffset + c; // == c * (vectorSize + 1)
                double weight = sumData[sumOffset + vectorSize];
                if (weight > 0.0) {
                    // centroid_c = sum_c / weight_c
                    BLAS.axpy(vectorSize, 1.0 / weight, sumData, sumOffset, centroidData, centroidOffset);
                } else {
                    // Empty cluster: reseed it with a randomly chosen sample.
                    int pick = random.nextInt(samples.length);
                    MatVecOp.appendVectorToMatrix(centroids.getVectors(), false, c, samples[pick].getVector());
                }
            }
            // Refresh the distance-specific cached data once, after every centroid column has been rewritten.
            distance.updateLabel(centroids);
        }

        if (converged) {
            LOG.info("Local kmeans converge with {} steps.", iter);
        } else {
            LOG.info("Local kmeans reach max iteration number!");
        }
        return centroids;
    }

    /**
     * Chooses {@code k} initial centroids using k-means++-style sampling.
     *
     * <p>The first centroid is drawn with probability proportional to the sample weight. Each following
     * centroid is drawn with probability proportional to {@code weight * d}, where {@code d} is the
     * smallest distance (as returned by {@code distance.calc}) from the sample to any centroid already
     * chosen. The same sample may be returned more than once, for example when all remaining costs are zero.
     *
     * @return the chosen samples, in pick order
     */
    private static List<FastDistanceVectorData> sampleInitialCentroids(int k, long[] sampleWeights,
                                                                       FastDistanceVectorData[] samples,
                                                                       FastDistance distance, Random random) {
        List<FastDistanceVectorData> centroids = new ArrayList<>(k);
        // Before the first pick every sample has the same cost, so selection depends on weight alone.
        double[] minDistances = new double[samples.length];
        Arrays.fill(minDistances, 1.0);

        int lastPicked = 0;
        for (int picked = 0; picked < k; picked++) {
            if (picked > 0) {
                FastDistanceData last = samples[lastPicked];
                for (int i = 0; i < samples.length; i++) {
                    // The casts select the FastDistanceData overload of calc, which returns a matrix.
                    double d = distance.calc((FastDistanceData) samples[i], last).get(0, 0);
                    // After the first centroid, overwrite the placeholder 1.0; afterwards keep the running minimum.
                    minDistances[i] = picked > 1 ? Math.min(d, minDistances[i]) : d;
                }
            }
            lastPicked = pickWeight(sampleWeights, minDistances, random);
            centroids.add(samples[lastPicked]);
        }
        return centroids;
    }

    /**
     * Picks index {@code i} with probability proportional to {@code weights[i] * costs[i]}.
     *
     * @return the chosen index; 0 when the total {@code sum(weights[i] * costs[i])} is not a positive
     *         finite number (e.g. every product is zero)
     */
    private static int pickWeight(long[] weights, double[] costs, Random random) {
        int n = weights.length;
        double[] cumulative = new double[n + 1];
        for (int i = 0; i < n; i++) {
            cumulative[i + 1] = cumulative[i] + weights[i] * costs[i];
        }
        double target = random.nextDouble() * cumulative[n];
        // Strict '>' guarantees the chosen index has a positive share, even when target == 0.0.
        for (int i = 1; i <= n; i++) {
            if (cumulative[i] > target) {
                return i - 1;
            }
        }
        return 0;
    }
}
