package com.alibaba.alink.operator.common.clustering.kmeans;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansAssignCluster.class */
/**
 * Per-task compute step of the KMeans training loop.
 *
 * <p>Each call fills this task's {@code double[]} buffer stored under
 * {@link KMeansTrainBatchOp#CENTROID_ALL_REDUCE}, by passing every local training sample to
 * {@link KMeansUtil#updateSumMatrix} together with the current centroids.
 *
 * <p>NOTE(review): the buffer has {@code k * (vectorSize + 1)} slots. Presumably these hold
 * per-cluster coordinate sums plus a per-cluster weight/count, which a later all-reduce step sums
 * across tasks. Confirm this against KMeansUtil and the queue built in KMeansTrainBatchOp.
 */
public class KMeansAssignCluster extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(KMeansAssignCluster.class);
    private static final long serialVersionUID = 1661237257173287045L;

    /** Distance implementation handed to updateSumMatrix; fixed at construction. */
    private final FastDistance fastDistance;

    /**
     * Scratch {@code k x 1} matrix passed to updateSumMatrix.
     *
     * <p>It is transient and created lazily in {@link #calc}, so a deserialized instance builds a
     * fresh one on first use.
     */
    private transient DenseMatrix distanceMatrix;

    /**
     * @param fastDistance distance implementation used when assigning samples to centroids
     */
    public KMeansAssignCluster(FastDistance fastDistance) {
        this.fastDistance = fastDistance;
    }

    /**
     * Recomputes this task's contribution to the centroid-sum buffer for the current step.
     *
     * <p>The buffer is always zeroed first. A task without training data ({@code "trainData"} is
     * null) therefore contributes zeros, rather than whatever the reused buffer held from an
     * earlier step.
     *
     * @param comContext per-task context holding the objects shared across iteration steps
     */
    @Override
    public void calc(ComContext comContext) {
        final int stepNo = comContext.getStepNo();
        final int taskId = comContext.getTaskId();
        LOG.info("StepNo {}, TaskId {} Assign cluster begins!", stepNo, taskId);

        final int vectorSize = (Integer) comContext.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
        final int k = (Integer) comContext.getObj(KMeansTrainBatchOp.K);

        // Even steps read CENTROID1 and odd steps read CENTROID2. This is presumably a double
        // buffer whose other half is written by the centroid-update step; TODO confirm.
        final Tuple2<?, ?> stepCentroids = stepNo % 2 == 0
            ? (Tuple2<?, ?>) comContext.getObj(KMeansTrainBatchOp.CENTROID1)
            : (Tuple2<?, ?>) comContext.getObj(KMeansTrainBatchOp.CENTROID2);

        if (distanceMatrix == null) {
            distanceMatrix = new DenseMatrix(k, 1);
        }

        // The buffer is created once and then reused from the context on every later step.
        double[] sumMatrix = (double[]) comContext.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
        if (sumMatrix == null) {
            sumMatrix = new double[k * (vectorSize + 1)];
            comContext.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrix);
        }
        // Reset BEFORE the data null-check, so that a task without data never re-sends stale
        // values that were left in the reused buffer.
        Arrays.fill(sumMatrix, 0.0);

        final Iterable<?> trainData = (Iterable<?>) comContext.getObj("trainData");
        if (trainData == null) {
            return;
        }
        for (Object sample : trainData) {
            // 1L is the per-sample weight passed to updateSumMatrix (every sample counts equally).
            KMeansUtil.updateSumMatrix((FastDistanceVectorData) sample, 1L,
                (FastDistanceMatrixData) stepCentroids.f1, vectorSize, sumMatrix, k,
                fastDistance, distanceMatrix);
        }
        LOG.info("StepNo {}, TaskId {} Assign cluster ends!", stepNo, taskId);
    }
}
