package com.alibaba.alink.operator.common.clustering.kmeans;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceData;
import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
import com.alibaba.alink.operator.common.nlp.WordCountUtil;
import org.apache.flink.api.java.tuple.Tuple2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansIterTermination.class */
/**
 * Convergence check for KMeans iterations.
 *
 * <p>{@link #calc(ComContext)} compares, centroid by centroid, the two centroid sets stored in the
 * context under {@code KMeansTrainBatchOp.CENTROID1} and {@code KMeansTrainBatchOp.CENTROID2}.
 * It returns {@code true} only when every one of the first {@code K} centroid pairs is closer than
 * {@code tol}. It returns {@code false} as soon as one pair is not. Presumably {@code true}
 * signals that the iteration may terminate (per the class name); confirm against
 * {@code CompareCriterionFunction}.
 */
public class KMeansIterTermination extends CompareCriterionFunction {
    private static final Logger LOG = LoggerFactory.getLogger(KMeansIterTermination.class);

    /**
     * Sets the strategy threshold. Up to this many centroids, distances come from one batched
     * {@code FastDistance.calc} call. Above it, centroids are compared one pair at a time.
     */
    private static final int MAX_K_NUMBER = WordCountUtil.BOUND_SIZE;
    private static final long serialVersionUID = 6636978614737723605L;

    private final FastDistance distance;
    private final double tol;

    /** Reused output buffer for the batched distance computation (small-K path). */
    private transient DenseMatrix distanceMatrix;
    /** Reused scratch buffers holding one centroid of each set (large-K path). */
    private transient double[] centroid1;
    private transient double[] centroid2;

    /**
     * @param fastDistance distance used to measure how far each centroid moved
     * @param tol          a centroid pair at distance {@code >= tol} means "not converged"
     */
    public KMeansIterTermination(FastDistance fastDistance, double tol) {
        this.distance = fastDistance;
        this.tol = tol;
    }

    /**
     * Checks whether every centroid pair is within tolerance.
     *
     * @param comContext context providing {@code K}, {@code VECTOR_SIZE}, {@code CENTROID1} and
     *                   {@code CENTROID2}; the {@code f1} field of both centroid tuples must be
     *                   {@link FastDistanceData} (small K) or {@link FastDistanceMatrixData} (large K)
     * @return {@code true} if all {@code K} centroid distances are below {@code tol}
     */
    @Override // com.alibaba.alink.common.comqueue.CompareCriterionFunction
    public boolean calc(ComContext comContext) {
        final int k = (Integer) comContext.getObj(KMeansTrainBatchOp.K);
        final int vectorSize = (Integer) comContext.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
        final Tuple2<?, ?> centroids1 = (Tuple2<?, ?>) comContext.getObj(KMeansTrainBatchOp.CENTROID1);
        final Tuple2<?, ?> centroids2 = (Tuple2<?, ?>) comContext.getObj(KMeansTrainBatchOp.CENTROID2);

        if (k <= MAX_K_NUMBER) {
            // One batched call. Only the diagonal (i, i) is read, i.e. centroid i of each set.
            this.distanceMatrix = this.distance.calc(
                (FastDistanceData) centroids1.f1, (FastDistanceData) centroids2.f1, this.distanceMatrix);
            for (int i = 0; i < k; i++) {
                if (movedBeyondTolerance(comContext, this.distanceMatrix.get(i, i))) {
                    return false;
                }
            }
            return true;
        }

        // Large K: avoid the batched call and compare one centroid pair at a time.
        // The indexing below assumes centroid i occupies [i * vectorSize, (i + 1) * vectorSize)
        // of the flat data array.
        final double[] data1 = ((FastDistanceMatrixData) centroids1.f1).getVectors().getData();
        final double[] data2 = ((FastDistanceMatrixData) centroids2.f1).getVectors().getData();
        // (Re)allocate scratch buffers when missing or when the vector size differs from last call.
        if (this.centroid1 == null || this.centroid1.length != vectorSize) {
            this.centroid1 = new double[vectorSize];
            this.centroid2 = new double[vectorSize];
        }
        for (int i = 0; i < k; i++) {
            final int offset = i * vectorSize;
            System.arraycopy(data1, offset, this.centroid1, 0, vectorSize);
            System.arraycopy(data2, offset, this.centroid2, 0, vectorSize);
            if (movedBeyondTolerance(comContext, this.distance.calc(this.centroid1, this.centroid2))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Logs one centroid-pair distance and reports whether it fails the tolerance check.
     * Uses {@code >=} so that behavior matches the original comparison exactly (NaN is not flagged).
     */
    private boolean movedBeyondTolerance(ComContext comContext, double movement) {
        LOG.info("StepNo {}, TaskId {} ||centroid-prev_centroid|| {}",
            comContext.getStepNo(), comContext.getTaskId(), movement);
        return movement >= this.tol;
    }
}
