package com.alibaba.alink.operator.common.clustering.agnes;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.utils.AlinkSerializable;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/agnes/Agnes.class */
/**
 * Agglomerative nesting (AGNES) hierarchical clustering.
 *
 * <p>Every sample starts in its own cluster; the closest pair of clusters under the chosen
 * {@link Linkage} is merged repeatedly until a stopping condition is reached.
 */
public class Agnes implements AlinkSerializable {
    private static final Logger LOG = LoggerFactory.getLogger(Agnes.class);

    /**
     * Runs the agglomerative clustering.
     *
     * <p>The samples in {@code list} are modified in place: their cluster id is assigned, and the
     * first sample of every absorbed cluster receives a parent id and the merge iteration number.
     *
     * @param list               samples to cluster; each one becomes an initial singleton cluster
     * @param i                  merging continues only while the number of clusters is greater than this value
     * @param d                  distance threshold; clustering stops after the first merge whose linkage
     *                           distance is greater than this value
     * @param linkage            how the distance between two clusters is derived from their samples
     * @param continuousDistance distance measure applied to sample vectors
     * @return the remaining clusters
     * @throws RuntimeException if {@code linkage} is not one of the supported values
     */
    public static List<AgnesCluster> startAnalysis(List<AgnesSample> list, int i, double d, Linkage linkage, ContinuousDistance continuousDistance) {
        List<AgnesCluster> clusters = initialCluster(list);
        int iteration = 1;
        while (clusters.size() > i) {
            double closest = Double.MAX_VALUE;
            int keepIndex = 0;
            int absorbIndex = 0;
            // Every ordered pair is evaluated; ties keep the first pair encountered (strict '<').
            for (int a = 0; a < clusters.size(); a++) {
                List<AgnesSample> samplesA = clusters.get(a).getAgnesSamples();
                for (int b = 0; b < clusters.size(); b++) {
                    if (a == b) {
                        continue;
                    }
                    double distance = clusterDistance(samplesA, clusters.get(b).getAgnesSamples(), linkage, continuousDistance);
                    if (distance < closest) {
                        closest = distance;
                        keepIndex = a;
                        absorbIndex = b;
                    }
                }
            }
            // Equal indices mean no candidate pair beat the initial bound (fewer than two clusters,
            // or every distance was NaN / not below Double.MAX_VALUE); nothing can be merged then.
            boolean pairFound = keepIndex != absorbIndex;
            clusters = mergeCluster(clusters, keepIndex, absorbIndex, iteration);
            LOG.info("Iteration:{}; distance:{}", iteration, closest);
            iteration++;
            // NOTE(review): the threshold is tested after the merge, so the first pair whose distance
            // exceeds d is still merged - confirm this is the intended semantics.
            // Also stop when no pair was found; otherwise the cluster count never shrinks and the
            // loop would spin forever for thresholds >= Double.MAX_VALUE.
            if (closest > d || !pairFound) {
                return clusters;
            }
        }
        return clusters;
    }

    /** Computes the distance between two clusters' samples according to the given linkage. */
    private static double clusterDistance(List<AgnesSample> samplesA, List<AgnesSample> samplesB, Linkage linkage, ContinuousDistance continuousDistance) {
        switch (linkage) {
            case MIN:
                return minDistance(samplesA, samplesB, continuousDistance);
            case MAX:
                return maxDistance(samplesA, samplesB, continuousDistance);
            case AVERAGE:
                return averageDistance(samplesA, samplesB, continuousDistance);
            case MEAN:
                return continuousDistance.calc(mean(samplesA, continuousDistance), mean(samplesB, continuousDistance));
            default:
                throw new RuntimeException("linkage not support:" + linkage);
        }
    }

    /** Smallest pairwise distance between a sample of one cluster and a sample of the other. */
    private static double minDistance(List<AgnesSample> samplesA, List<AgnesSample> samplesB, ContinuousDistance continuousDistance) {
        double smallest = Double.MAX_VALUE;
        for (AgnesSample sampleA : samplesA) {
            for (AgnesSample sampleB : samplesB) {
                double distance = continuousDistance.calc(sampleA.getVector(), sampleB.getVector());
                if (distance < smallest) {
                    smallest = distance;
                }
            }
        }
        return smallest;
    }

    /** Largest pairwise distance between a sample of one cluster and a sample of the other. */
    private static double maxDistance(List<AgnesSample> samplesA, List<AgnesSample> samplesB, ContinuousDistance continuousDistance) {
        // Double.MIN_VALUE is the smallest POSITIVE double, so it cannot seed a running maximum:
        // clusters whose pairwise distances are all 0 would report 4.9e-324 instead of 0.
        double largest = Double.NEGATIVE_INFINITY;
        for (AgnesSample sampleA : samplesA) {
            for (AgnesSample sampleB : samplesB) {
                double distance = continuousDistance.calc(sampleA.getVector(), sampleB.getVector());
                if (distance > largest) {
                    largest = distance;
                }
            }
        }
        return largest;
    }

    /** Mean of all pairwise distances between the samples of the two clusters. */
    private static double averageDistance(List<AgnesSample> samplesA, List<AgnesSample> samplesB, ContinuousDistance continuousDistance) {
        double total = 0.0d;
        for (AgnesSample sampleA : samplesA) {
            for (AgnesSample sampleB : samplesB) {
                total += continuousDistance.calc(sampleA.getVector(), sampleB.getVector());
            }
        }
        // Divide by the number of pairs (|A| * |B|), not by |A| alone; the product is formed in
        // double arithmetic so it cannot overflow int.
        return total / ((double) samplesA.size() * samplesB.size());
    }

    /**
     * Merges the cluster at index {@code i2} into the cluster at index {@code i} and removes the
     * absorbed cluster from {@code list}. Does nothing when both indices are equal.
     *
     * @param list clusters, modified in place
     * @param i    index of the cluster that survives
     * @param i2   index of the cluster that is absorbed
     * @param i3   iteration number recorded on the absorbed cluster's first sample
     * @return the same {@code list} instance
     */
    private static List<AgnesCluster> mergeCluster(List<AgnesCluster> list, int i, int i2, int i3) {
        if (i == i2) {
            return list;
        }
        AgnesCluster survivor = list.get(i);
        AgnesCluster absorbed = list.get(i2);
        // Record the merge on the absorbed cluster's first sample: who it was merged into, and when.
        AgnesSample absorbedFirst = absorbed.getFirstSample();
        absorbedFirst.setParentId(survivor.getFirstSample().getSampleId());
        absorbedFirst.setMergeIter(i3);
        for (AgnesSample sample : absorbed.getAgnesSamples()) {
            sample.setClusterId(survivor.getClusterId());
            survivor.addDataPoints(sample);
        }
        // i2 is a primitive int, so this removes by index.
        list.remove(i2);
        return list;
    }

    /**
     * Creates one singleton cluster per sample, using the sample's position in {@code list} as the
     * cluster id (also written back to the sample).
     */
    private static List<AgnesCluster> initialCluster(List<AgnesSample> list) {
        List<AgnesCluster> clusters = new ArrayList<>(list.size());
        for (int index = 0; index < list.size(); index++) {
            AgnesSample sample = list.get(index);
            sample.setClusterId(index);
            clusters.add(new AgnesCluster(index, sample));
        }
        return clusters;
    }

    /**
     * Computes the component-wise mean of the sample vectors.
     *
     * @param list               samples to average
     * @param continuousDistance not used by this method; kept so the signature stays unchanged
     * @return a new vector holding the mean, or {@code null} when {@code list} is null or empty
     */
    public static DenseVector mean(List<AgnesSample> list, ContinuousDistance continuousDistance) {
        if (null == list || list.isEmpty()) {
            return null;
        }
        // The dimension is taken from the first sample.
        DenseVector sum = new DenseVector(list.get(0).getVector().size());
        for (AgnesSample sample : list) {
            sum.plusEqual(sample.getVector());
        }
        sum.scaleEqual(1.0d / list.size());
        return sum;
    }
}
