package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.KMeansModelInfoBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansAssignCluster;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansInitCentroids;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansIterTermination;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansOutputModel;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansPreallocateCentroid;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansUpdateCentroids;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.params.clustering.KMeansTrainParams;
import com.alibaba.alink.params.shared.clustering.HasKMeansDistanceType;
import com.alibaba.alink.params.shared.clustering.HasKMeansWithHaversineDistanceType;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Batch operator that trains a KMeans clustering model.
 *
 * <p>The input vector column is summarized to obtain the vector dimension, every vector is converted
 * into {@link FastDistanceVectorData}, initial centroids are chosen by {@link KMeansInitCentroids}, and
 * the centroids are then refined iteratively (assign clusters, all-reduce, update centroids) until the
 * termination criterion configured by {@code epsilon} is met or {@code maxIter} iterations have run.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.KMEANS_MODEL)})
// The operator has a single DATA input port, so the vector column is selected from port 0.
@ParamSelectColumnSpec(name = "vectorCol", portIndices = {0}, allowedTypeCollections = {TypeCollections.VECTOR_TYPES})
@NameCn("K均值聚类训练")
@NameEn("KMeans Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.KMeans")
public final class KMeansTrainBatchOp extends BatchOperator<KMeansTrainBatchOp>
    implements KMeansTrainParams<KMeansTrainBatchOp>,
    WithModelInfoBatchOp<KMeansModelInfoBatchOp.KMeansModelInfo, KMeansTrainBatchOp, KMeansModelInfoBatchOp> {

    /** Name of the partitioned training data inside the iterative communication queue. */
    public static final String TRAIN_DATA = "trainData";
    /** Name of the broadcast initial-centroid data inside the iterative communication queue. */
    public static final String INIT_CENTROID = "initCentroid";
    public static final String CENTROID1 = "centroid1";
    public static final String CENTROID2 = "centroid2";
    /** Name of the buffer that is all-reduced across workers during each iteration. */
    public static final String CENTROID_ALL_REDUCE = "centroidAllReduce";
    /** Broadcast name under which the vector dimension is handed to the iteration. */
    public static final String KMEANS_STATISTICS = "statistics";
    /** Broadcast name under which the vector dimension is handed to the vector-preparation map. */
    public static final String VECTOR_SIZE = "vectorSize";
    public static final String K = "k";
    private static final long serialVersionUID = -1848822118021355321L;

    public KMeansTrainBatchOp() {
        this(null);
    }

    public KMeansTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Creates an operator that extracts model information from the model produced by this operator.
     *
     * @return a {@link KMeansModelInfoBatchOp} sharing this operator's params and linked from it
     */
    @Override
    public KMeansModelInfoBatchOp getModelInfoBatchOp() {
        return new KMeansModelInfoBatchOp(getParams()).linkFrom(this);
    }

    /**
     * Runs the iterative KMeans refinement and emits the resulting model rows.
     *
     * @param initCentroids    initial centroids, broadcast to every worker
     * @param data             prepared training vectors, partitioned across workers
     * @param vectorSize       single-element data set holding the vector dimension (broadcast as
     *                         {@link #KMEANS_STATISTICS})
     * @param maxIter          maximum number of iterations
     * @param epsilon          convergence tolerance handed to {@link KMeansIterTermination}
     * @param distance         distance used for assignment, update and termination
     * @param distanceType     distance type recorded in the output model
     * @param vectorColName    name of the vector column recorded in the output model
     * @param latitudeColName  NOTE(review): presumably the latitude column for the haversine (geo)
     *                         variant; {@link #linkFrom} always passes {@code null} — confirm against
     *                         {@link KMeansOutputModel}
     * @param longitudeColName NOTE(review): presumably the longitude column for the haversine (geo)
     *                         variant; {@link #linkFrom} always passes {@code null} — confirm against
     *                         {@link KMeansOutputModel}
     * @return the model rows produced by {@link KMeansOutputModel}
     */
    public static DataSet<Row> iterateICQ(DataSet<FastDistanceMatrixData> initCentroids,
                                          DataSet<FastDistanceVectorData> data,
                                          DataSet<Integer> vectorSize,
                                          int maxIter,
                                          double epsilon,
                                          FastDistance distance,
                                          HasKMeansWithHaversineDistanceType.DistanceType distanceType,
                                          String vectorColName,
                                          String latitudeColName,
                                          String longitudeColName) {
        return new IterativeComQueue()
            .initWithPartitionedData(TRAIN_DATA, data)
            .initWithBroadcastData(INIT_CENTROID, initCentroids)
            .initWithBroadcastData(KMEANS_STATISTICS, vectorSize)
            .add(new KMeansPreallocateCentroid())
            .add(new KMeansAssignCluster(distance))
            .add(new AllReduce(CENTROID_ALL_REDUCE))
            .add(new KMeansUpdateCentroids(distance))
            .setCompareCriterionOfNode0(new KMeansIterTermination(distance, epsilon))
            .closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName))
            .setMaxIter(maxIter)
            .exec();
    }

    /**
     * Trains the KMeans model from the first input operator.
     *
     * <p>The bridge method {@code linkFrom(BatchOperator[])} is generated by the compiler for this
     * covariant override; declaring it by hand would clash with this method's erasure.
     *
     * @param inputs the input operators; only the first one is used
     * @return this operator, with its output set to the model rows
     */
    @Override
    public KMeansTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        final int maxIter = getMaxIter();
        final double epsilon = getEpsilon();
        final String vectorColName = getVectorCol();
        final HasKMeansDistanceType.DistanceType distanceType = getDistanceType();
        final FastDistance distance = distanceType.getFastDistance();

        // f0: the vectors of the selected column; f1: their summary statistics.
        Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> vectorsAndSummary =
            StatisticsHelper.summaryHelper(in, null, vectorColName);

        DataSet<Integer> vectorSize = vectorsAndSummary.f1.map(new VectorSizeOfSummary());

        DataSet<FastDistanceVectorData> data = vectorsAndSummary.f0
            .rebalance()
            .map(new PrepareVectorData(distance))
            .withBroadcastSet(vectorSize, VECTOR_SIZE);

        DataSet<FastDistanceMatrixData> initCentroids = KMeansInitCentroids.initKmeansCentroids(
            data, distance, getParams(), vectorSize, getRandomSeed());

        DataSet<Row> model = iterateICQ(initCentroids, data, vectorSize, maxIter, epsilon, distance,
            HasKMeansWithHaversineDistanceType.DistanceType.valueOf(distanceType.name()),
            vectorColName, null, null);

        setOutput(model, new KMeansModelDataConverter().getModelSchema());
        return this;
    }

    /**
     * Extracts the vector dimension from the training-data summary.
     * It fails with {@link AkIllegalDataException} when the dataset is empty.
     */
    private static final class VectorSizeOfSummary implements MapFunction<BaseVectorSummary, Integer> {
        private static final long serialVersionUID = 4184586558834055401L;

        @Override
        public Integer map(BaseVectorSummary summary) {
            AkPreconditions.checkArgument(summary.count() > 0,
                new AkIllegalDataException("The train dataset is empty!"));
            return summary.vectorSize();
        }
    }

    /**
     * Converts each training vector into {@link FastDistanceVectorData} for the configured distance.
     * It reads the vector dimension from the {@link #VECTOR_SIZE} broadcast variable.
     */
    private static final class PrepareVectorData extends RichMapFunction<Vector, FastDistanceVectorData> {
        private static final long serialVersionUID = -7443226889326704768L;

        private final FastDistance distance;
        private int vectorSize;

        PrepareVectorData(FastDistance distance) {
            this.distance = distance;
        }

        @Override
        public void open(Configuration parameters) {
            this.vectorSize = getRuntimeContext().<Integer>getBroadcastVariable(VECTOR_SIZE).get(0);
        }

        @Override
        public FastDistanceVectorData map(Vector vector) {
            // Sparse vectors are resized to the dataset-wide dimension so that all vectors agree.
            if (vector instanceof SparseVector) {
                ((SparseVector) vector).setSize(vectorSize);
            }
            return distance.prepareVectorData(Row.of(vector), 0, new int[0]);
        }
    }
}
