package com.alibaba.alink.operator.common.clustering.kmeans;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.Arrays;
import org.apache.flink.api.java.tuple.Tuple2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansUpdateCentroids.class */
/**
 * KMeans iteration step that rebuilds the centroids from the all-reduced accumulation buffer.
 *
 * <p>Reads {@link KMeansTrainBatchOp#VECTOR_SIZE}, {@link KMeansTrainBatchOp#K} and
 * {@link KMeansTrainBatchOp#CENTROID_ALL_REDUCE} from the context. It writes the new centroids into
 * the centroid slot chosen by the parity of the step number, and stores the number of centroids
 * actually written back under {@link KMeansTrainBatchOp#K}.
 */
public class KMeansUpdateCentroids extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(KMeansUpdateCentroids.class);
    private static final long serialVersionUID = -5638042710336233392L;

    /** Distance whose {@code updateLabel} is invoked on the centroids after they are recomputed. */
    private final FastDistance distance;

    /**
     * Creates the step.
     *
     * @param fastDistance distance passed to {@link #updateCentroids} on every step
     */
    public KMeansUpdateCentroids(FastDistance fastDistance) {
        this.distance = fastDistance;
    }

    /**
     * Recomputes the centroids for the current step and publishes the resulting cluster count.
     *
     * @param comContext iteration context holding the shared KMeans objects
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        final int stepNo = comContext.getStepNo();
        LOG.info("StepNo {}, TaskId {} Update cluster begins!", stepNo, comContext.getTaskId());

        final int vectorSize = (Integer) comContext.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
        final int k = (Integer) comContext.getObj(KMeansTrainBatchOp.K);
        final double[] buffer = (double[]) comContext.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);

        // Even steps write into CENTROID2, odd steps into CENTROID1.
        @SuppressWarnings("unchecked")
        final Tuple2<Integer, FastDistanceMatrixData> stepAndCentroids =
            (Tuple2<Integer, FastDistanceMatrixData>) comContext.getObj(
                stepNo % 2 == 0 ? KMeansTrainBatchOp.CENTROID2 : KMeansTrainBatchOp.CENTROID1);
        // Tag the slot with the step that produced it.
        stepAndCentroids.f0 = stepNo;

        final int written = updateCentroids(stepAndCentroids.f1, k, vectorSize, buffer, distance);
        comContext.putObj(KMeansTrainBatchOp.K, written);
        LOG.info("StepNo {}, TaskId {} Update cluster ends!", stepNo, comContext.getTaskId());
    }

    /**
     * Recomputes the centroids in place from an accumulation buffer.
     *
     * <p>Buffer layout:
     * <ul>
     *   <li>The buffer holds {@code k} consecutive records of {@code vectorSize + 1} doubles.
     *   <li>Record {@code i} starts at offset {@code i * (vectorSize + 1)}.
     *   <li>The last value of each record is its weight.
     *   <li>The first {@code vectorSize} values are scaled by {@code 1 / weight} to form the centroid.
     * </ul>
     *
     * <p>Records whose weight is exactly zero are skipped. The surviving centroids are packed to the
     * front of the centroid data: the {@code j}-th written centroid occupies
     * {@code data[j * vectorSize, (j + 1) * vectorSize)}. Every other entry of the centroid data is
     * left at zero. Finally {@code distance.updateLabel(centroids)} is called.
     *
     * @param centroids  centroids to overwrite; their vector data is cleared first
     * @param k          number of records in {@code buffer}
     * @param vectorSize dimension of each centroid
     * @param buffer     accumulation buffer; assumed to hold at least {@code k * (vectorSize + 1)}
     *                   values
     * @param distance   distance whose {@code updateLabel} is invoked on the updated centroids
     * @return the number of records with a non-zero weight, i.e. the number of centroids written
     */
    static int updateCentroids(FastDistanceMatrixData centroids, int k, int vectorSize, double[] buffer,
                               FastDistance distance) {
        final double[] data = centroids.getVectors().getData();
        // BLAS.axpy adds into its target (y += a * x), so the target must start at zero.
        Arrays.fill(data, 0.0);

        final int stride = vectorSize + 1;
        int written = 0;
        for (int record = 0; record < k; record++) {
            final int offset = record * stride;
            final double weight = buffer[offset + vectorSize];
            // A zero weight (presumably an empty cluster) is dropped rather than divided by.
            if (weight != 0.0) {
                BLAS.axpy(vectorSize, 1.0 / weight, buffer, offset, data, written * vectorSize);
                written++;
            }
        }
        distance.updateLabel(centroids);
        return written;
    }
}
