package com.alibaba.alink.operator.common.clustering;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.operator.common.clustering.common.Center;
import com.alibaba.alink.operator.common.clustering.common.Cluster;
import com.alibaba.alink.operator.common.clustering.common.Sample;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/LocalKMeans.class */
public class LocalKMeans {
    public static void clustering(Iterable<Sample> iterable, Collector<Sample> collector, int i, double d, int i2, ContinuousDistance continuousDistance) {
        ArrayList arrayList = new ArrayList();
        Iterator<Sample> it = iterable.iterator();
        while (it.hasNext()) {
            arrayList.add(it.next());
        }
        for (Sample sample : new LocalKMeans().clustering((Sample[]) arrayList.toArray(new Sample[arrayList.size()]), i, d, i2, continuousDistance)) {
            collector.collect(sample);
        }
    }

    public static FindResult findCluster(Center[] centerArr, DenseVector denseVector, ContinuousDistance continuousDistance) {
        long j = -1;
        double d = Double.POSITIVE_INFINITY;
        for (Center center : centerArr) {
            if (null != center) {
                double calc = continuousDistance.calc(denseVector, center.getVector());
                if (calc < d) {
                    j = center.getClusterId();
                    d = calc;
                }
            }
        }
        return new FindResult(Long.valueOf(j), Double.valueOf(d));
    }

    public static Center[] getCentersFromClusters(Sample[] sampleArr, int i) {
        Cluster[] clusterArr = new Cluster[i];
        for (int i2 = 0; i2 < i; i2++) {
            clusterArr[i2] = new Cluster();
        }
        for (int i3 = 0; i3 < sampleArr.length; i3++) {
            clusterArr[(int) sampleArr[i3].getClusterId()].addSample(sampleArr[i3]);
        }
        ArrayList arrayList = new ArrayList();
        for (int i4 = 0; i4 < clusterArr.length; i4++) {
            if (clusterArr[i4].getCenter().getVector() != null) {
                arrayList.add(clusterArr[i4].getCenter());
            }
        }
        for (int i5 = 0; i5 < arrayList.size(); i5++) {
            ((Center) arrayList.get(i5)).setClusterId(i5);
        }
        return (Center[]) arrayList.toArray(new Center[0]);
    }

    public Sample[] clustering(Sample[] sampleArr, int i, double d, int i2, ContinuousDistance continuousDistance) {
        kMeansClustering(sampleArr, d, i > sampleArr.length ? sampleArr.length : i, continuousDistance, i2);
        return sampleArr;
    }

    private Center[] getInitialCenters(Sample[] sampleArr, int i) {
        Center[] centerArr = new Center[i];
        int length = sampleArr.length;
        boolean[] zArr = new boolean[length];
        Arrays.fill(zArr, false);
        if (i < length / 3) {
            int i2 = 0;
            while (i2 < i) {
                int nextInt = new Random().nextInt(length);
                if (!zArr[nextInt]) {
                    centerArr[i2] = new Center(i2, 0L, sampleArr[nextInt].getVector());
                    i2++;
                }
                zArr[nextInt] = true;
            }
        } else {
            for (int i3 = 0; i3 < i; i3++) {
                centerArr[i3] = new Center(i3, 0L, sampleArr[i3].getVector());
            }
        }
        return centerArr;
    }

    private Center[] kMeansClustering(Sample[] sampleArr, double d, int i, ContinuousDistance continuousDistance, int i2) {
        Center[] initialCenters = getInitialCenters(sampleArr, i);
        int i3 = 0;
        double d2 = Criteria.INVALID_GAIN;
        while (true) {
            double d3 = d2;
            int i4 = i3;
            i3++;
            if (i4 >= i2) {
                break;
            }
            double d4 = 0.0d;
            for (int i5 = 0; i5 < sampleArr.length; i5++) {
                FindResult findCluster = findCluster(initialCenters, sampleArr[i5].getVector(), continuousDistance);
                sampleArr[i5].setClusterId(findCluster.getClusterId().longValue());
                d4 += findCluster.getDistance().doubleValue();
            }
            initialCenters = getCentersFromClusters(sampleArr, i);
            if (Math.abs(d4 - d3) / sampleArr.length < d) {
                break;
            }
            d2 = d4;
        }
        return initialCenters;
    }
}
