package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.BisectingKMeansModelInfoBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelData;
import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelDataConverter;
import com.alibaba.alink.operator.common.dataproc.FirstReducer;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.clustering.BisectingKMeansTrainParams;
import com.alibaba.alink.params.shared.clustering.HasKMeansDistanceType;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "vectorCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.VECTOR_TYPES})
@NameCn("二分K均值聚类训练")
@NameEn("Bisecting KMeans Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.BisectingKMeans")
/* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/BisectingKMeansTrainBatchOp.class */
public final class BisectingKMeansTrainBatchOp extends BatchOperator<BisectingKMeansTrainBatchOp> implements BisectingKMeansTrainParams<BisectingKMeansTrainBatchOp>, WithModelInfoBatchOp<BisectingKMeansModelInfoBatchOp.BisectingKMeansModelInfo, BisectingKMeansTrainBatchOp, BisectingKMeansModelInfoBatchOp> {
    // Presumably the index of the root cluster of the bisecting tree; children of cluster j are
    // 2j (left) and 2j + 1 (right), see leftChildIndex/rightChildIndex.
    public static final long ROOT_INDEX = 1;
    private static final Logger LOG = LoggerFactory.getLogger(BisectingKMeansTrainBatchOp.class);
    // Broadcast-variable names. VECTOR_SIZE has the same value as the "vectorSize" broadcast name
    // used by SaveModel and summary(); the other three are used by updateAssignment/UpdateAssignment.
    private static final String VECTOR_SIZE = "vectorSize";
    private static final String DIVISIBLE_INDICES = "divisibleIndices";
    private static final String ITER_INFO = "iterInfo";
    private static final String NEW_CLUSTER_CENTERS = "newClusterCenters";
    private static final long serialVersionUID = 2370661722946764779L;

    /**
     * Mergeable accumulator of the statistics of one cluster: point count, sum of squared L2 norms,
     * and the sum of the points (Euclidean) or of the L2-normalized points (cosine).
     */
    public static class ClusterSummaryAggregator implements Serializable {
        private static final long serialVersionUID = 1151203222695380011L;
        private long count;
        private DenseVector sum;
        private double sumSqured;
        private HasKMeansDistanceType.DistanceType distanceType;

        /** No-arg constructor kept for serialization frameworks. */
        ClusterSummaryAggregator() {
        }

        /**
         * @param i            dimension of the vectors to aggregate
         * @param distanceType distance the cluster summary is computed for
         */
        ClusterSummaryAggregator(int i, HasKMeansDistanceType.DistanceType distanceType) {
            this.sum = new DenseVector(i);
            this.distanceType = distanceType;
        }

        /**
         * Adds one point to the aggregate.
         *
         * @throws AkIllegalDataException for non-Euclidean distance when the vector's squared norm is
         *                                not greater than {@code Criteria.INVALID_GAIN}
         */
        public void add(Vector vector) {
            this.count++;
            double squaredNorm = MatVecOp.dot(vector, vector);
            this.sumSqured += squaredNorm;
            if (this.distanceType == HasKMeansDistanceType.DistanceType.EUCLIDEAN) {
                BLAS.axpy(1.0d, vector, this.sum);
            } else {
                // Build the exception only when the check fails; constructing it eagerly for every
                // point (as an argument to a precondition helper) fills a stack trace per record.
                if (!(squaredNorm > Criteria.INVALID_GAIN)) {
                    throw new AkIllegalDataException("The L2 norm must not be zero when using cosine distance.");
                }
                BLAS.axpy(1.0d / Math.sqrt(squaredNorm), vector, this.sum);
            }
        }

        /** Folds another aggregate of the same cluster into this one. */
        public void merge(ClusterSummaryAggregator clusterSummaryAggregator) {
            this.count += clusterSummaryAggregator.count;
            this.sumSqured += clusterSummaryAggregator.sumSqured;
            BLAS.axpy(1.0d, clusterSummaryAggregator.sum, this.sum);
        }

        /**
         * Converts the aggregate into a cluster summary. The center is the mean of the accumulated
         * vectors; for non-Euclidean distance it is additionally scaled to unit L2 norm.
         */
        public BisectingKMeansModelData.ClusterSummary toClusterSummary() {
            BisectingKMeansModelData.ClusterSummary clusterSummary = new BisectingKMeansModelData.ClusterSummary();
            DenseVector center = this.sum.scale(1.0d / this.count);
            if (this.distanceType != HasKMeansDistanceType.DistanceType.EUCLIDEAN) {
                center.scaleEqual(1.0d / Math.sqrt(BLAS.dot(center, center)));
            }
            clusterSummary.center = center;
            clusterSummary.cost = calcClusterCost(this.distanceType, clusterSummary.center, this.sum, this.count, this.sumSqured);
            clusterSummary.size = this.count;
            return clusterSummary;
        }

        /**
         * Cluster cost, clipped from below at {@code Criteria.INVALID_GAIN}.
         * Euclidean: {@code sumSquared - count * ||center||^2}; otherwise
         * {@code count - center.sum / ||center||}.
         */
        private static double calcClusterCost(HasKMeansDistanceType.DistanceType distanceType, DenseVector denseVector, DenseVector denseVector2, long j, double d) {
            if (distanceType == HasKMeansDistanceType.DistanceType.EUCLIDEAN) {
                return Math.max(d - (j * BLAS.dot(denseVector, denseVector)), Criteria.INVALID_GAIN);
            }
            return Math.max(j - (BLAS.dot(denseVector, denseVector2) / Math.sqrt(BLAS.dot(denseVector, denseVector))), Criteria.INVALID_GAIN);
        }
    }

    /**
     * Per-cluster iteration bookkeeping: which bisecting step and which inner (k-means) step the
     * training is at, plus flags describing the cluster's role in the current bisection.
     */
    public static class IterInfo implements Serializable {
        private static final long serialVersionUID = 782571396646278396L;
        public int bisectingStepNo;
        public int innerIterStepNo;
        public int maxIter;
        public boolean isDividing = false;
        public boolean isNew = false;
        public boolean shouldStopSplit = false;

        /** Public no-arg constructor; all counters are 0 and all flags are false. */
        public IterInfo() {
        }

        /** Starts at bisecting step 0, inner step 0, with {@code i} inner steps per bisection. */
        IterInfo(int i) {
            this(i, 0, 0, false, false, false);
        }

        /**
         * @param i  maximum number of inner steps per bisecting step
         * @param i2 bisecting step number
         * @param i3 inner step number
         * @param z  whether the cluster is being divided
         * @param z2 whether the cluster was just created by a split
         * @param z3 whether splitting should stop after this bisection
         */
        IterInfo(int i, int i2, int i3, boolean z, boolean z2, boolean z3) {
            maxIter = i;
            bisectingStepNo = i2;
            innerIterStepNo = i3;
            isDividing = z;
            isNew = z2;
            shouldStopSplit = z3;
        }

        @Override
        public String toString() {
            return JsonConverter.toJson(this);
        }

        /**
         * Advances one inner step; after {@code maxIter} inner steps moves on to the next bisecting
         * step and clears the dividing/new flags.
         */
        public void updateIterInfo() {
            innerIterStepNo += 1;
            if (innerIterStepNo < maxIter) {
                return;
            }
            bisectingStepNo += 1;
            innerIterStepNo = 0;
            isDividing = false;
            isNew = false;
        }

        /** True at the first inner step of a bisecting step, when clusters are split. */
        public boolean doBisectionInStep() {
            return 0 == innerIterStepNo;
        }

        /** True at the final inner step of the current bisecting step. */
        public boolean atLastInnerIterStep() {
            return maxIter - 1 == innerIterStepNo;
        }

        /** True when no further bisection is planned. */
        public boolean atLastBisectionStep() {
            return shouldStopSplit;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/BisectingKMeansTrainBatchOp$SaveModel.class */
    public static class SaveModel extends RichMapPartitionFunction<Tuple2<Long, BisectingKMeansModelData.ClusterSummary>, Row> {
        private static final long serialVersionUID = -1963519415361096951L;
        private HasKMeansDistanceType.DistanceType distanceType;
        private String vectorColName;
        private int k;

        SaveModel(HasKMeansDistanceType.DistanceType distanceType, String str, int i) {
            this.distanceType = distanceType;
            this.vectorColName = str;
            this.k = i;
        }

        public void mapPartition(Iterable<Tuple2<Long, BisectingKMeansModelData.ClusterSummary>> iterable, Collector<Row> collector) {
            AkPreconditions.checkArgument(getRuntimeContext().getNumberOfParallelSubtasks() <= 1, "parallelism greater than one when saving model.");
            int intValue = ((Integer) getRuntimeContext().getBroadcastVariable("vectorSize").get(0)).intValue();
            BisectingKMeansModelData bisectingKMeansModelData = new BisectingKMeansModelData();
            bisectingKMeansModelData.summaries = new HashMap(0);
            bisectingKMeansModelData.vectorSize = intValue;
            bisectingKMeansModelData.distanceType = this.distanceType;
            bisectingKMeansModelData.vectorColName = this.vectorColName;
            bisectingKMeansModelData.k = this.k;
            iterable.forEach(tuple2 -> {
            });
            new BisectingKMeansModelDataConverter().save(bisectingKMeansModelData, collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/BisectingKMeansTrainBatchOp$UpdateAssignment.class */
    public static class UpdateAssignment extends RichGroupReduceFunction<Tuple4<Integer, Long, Vector, Long>, Tuple3<Long, Vector, Long>> {
        private static final long serialVersionUID = 7979966548718231315L;
        transient Set<Long> divisibleIndices;
        transient Map<Long, DenseVector> newClusterCenters;
        transient boolean shouldInitState;
        transient boolean shouldUpdateState;
        transient List<Tuple2<Long, Long>> assignmentInState;
        transient Map<Long, Tuple2<DenseVector, Double>> middlePlanes;
        ContinuousDistance distance;

        UpdateAssignment(ContinuousDistance continuousDistance) {
            this.distance = continuousDistance;
        }

        public void open(Configuration configuration) {
            this.divisibleIndices = new HashSet(getRuntimeContext().getBroadcastVariable(BisectingKMeansTrainBatchOp.DIVISIBLE_INDICES));
            this.shouldUpdateState = ((IterInfo) ((Tuple1) getRuntimeContext().getBroadcastVariable(BisectingKMeansTrainBatchOp.ITER_INFO).get(0)).f0).atLastInnerIterStep();
            this.shouldInitState = getIterationRuntimeContext().getSuperstepNumber() == 1;
            List broadcastVariable = getRuntimeContext().getBroadcastVariable(BisectingKMeansTrainBatchOp.NEW_CLUSTER_CENTERS);
            this.newClusterCenters = new HashMap(0);
            broadcastVariable.forEach(tuple2 -> {
            });
            if (this.distance instanceof EuclideanDistance) {
                this.middlePlanes = new HashMap(0);
                this.divisibleIndices.forEach(l -> {
                    long leftChildIndex = BisectingKMeansTrainBatchOp.leftChildIndex(l.longValue());
                    long rightChildIndex = BisectingKMeansTrainBatchOp.rightChildIndex(l.longValue());
                    DenseVector plus = this.newClusterCenters.get(Long.valueOf(rightChildIndex)).plus((Vector) this.newClusterCenters.get(Long.valueOf(leftChildIndex)));
                    DenseVector minus = this.newClusterCenters.get(Long.valueOf(rightChildIndex)).minus((Vector) this.newClusterCenters.get(Long.valueOf(leftChildIndex)));
                    BLAS.scal(0.5d, plus);
                    this.middlePlanes.put(l, Tuple2.of(minus, Double.valueOf(BLAS.dot(plus, minus))));
                });
            }
            if (this.shouldInitState) {
                this.assignmentInState = new ArrayList();
            }
        }

        public void reduce(Iterable<Tuple4<Integer, Long, Vector, Long>> iterable, Collector<Tuple3<Long, Vector, Long>> collector) {
            long closestNode;
            int i = 0;
            for (Tuple4<Integer, Long, Vector, Long> tuple4 : iterable) {
                long longValue = ((Long) tuple4.f3).longValue();
                if (this.shouldInitState) {
                    this.assignmentInState.add(Tuple2.of(tuple4.f1, tuple4.f3));
                } else {
                    if (!((Long) tuple4.f1).equals(this.assignmentInState.get(i).f0)) {
                        throw new AkIllegalStateException("Data out of order.");
                    }
                    longValue = ((Long) this.assignmentInState.get(i).f1).longValue();
                }
                if (this.divisibleIndices.contains(Long.valueOf(longValue))) {
                    long leftChildIndex = BisectingKMeansTrainBatchOp.leftChildIndex(longValue);
                    long rightChildIndex = BisectingKMeansTrainBatchOp.rightChildIndex(longValue);
                    if (this.distance instanceof EuclideanDistance) {
                        Tuple2<DenseVector, Double> tuple2 = this.middlePlanes.get(Long.valueOf(longValue));
                        closestNode = MatVecOp.dot((Vector) tuple4.f2, (Vector) tuple2.f0) < ((Double) tuple2.f1).doubleValue() ? leftChildIndex : rightChildIndex;
                    } else {
                        closestNode = BisectingKMeansTrainBatchOp.getClosestNode(leftChildIndex, this.newClusterCenters.get(Long.valueOf(leftChildIndex)), rightChildIndex, this.newClusterCenters.get(Long.valueOf(rightChildIndex)), (Vector) tuple4.f2, this.distance);
                    }
                    collector.collect(Tuple3.of(tuple4.f1, tuple4.f2, Long.valueOf(closestNode)));
                    if (this.shouldUpdateState) {
                        this.assignmentInState.set(i, Tuple2.of(tuple4.f1, Long.valueOf(closestNode)));
                    }
                }
                i++;
            }
        }
    }

    /**
     * Creates the training operator with the given parameters.
     *
     * @param params training parameters
     */
    public BisectingKMeansTrainBatchOp(Params params) {
        super(params);
    }

    /** Creates the training operator with an empty {@link Params}. */
    public BisectingKMeansTrainBatchOp() {
        this(new Params());
    }

    /**
     * Builds a model-info operator that shares this operator's parameters and is linked to it.
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public BisectingKMeansModelInfoBatchOp getModelInfoBatchOp() {
        BisectingKMeansModelInfoBatchOp infoOp = new BisectingKMeansModelInfoBatchOp(getParams());
        return infoOp.linkFrom(this);
    }

    /** Index of the left child of cluster {@code j} in the heap-style tree numbering ({@code 2j}). */
    public static long leftChildIndex(long j) {
        return j << 1;
    }

    /** Index of the right child of cluster {@code j} ({@code 2j + 1}). */
    public static long rightChildIndex(long j) {
        return leftChildIndex(j) + 1;
    }

    /**
     * Extracts {@code (clusterId, center)} for clusters flagged as newly created by a split.
     * The decompiler-emitted synthetic bridge {@code flatMap(Object, Collector)} is removed; the
     * compiler generates it.
     */
    private static DataSet<Tuple2<Long, DenseVector>> getNewClusterCenters(DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> dataSet) {
        return dataSet.flatMap(new FlatMapFunction<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>, Tuple2<Long, DenseVector>>() {
            private static final long serialVersionUID = 3213199884288120376L;

            @Override
            public void flatMap(Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster, Collector<Tuple2<Long, DenseVector>> out) {
                if (cluster.f2.isNew) {
                    out.collect(Tuple2.of(cluster.f0, cluster.f1.center));
                }
            }
        }).name("getNewClusterCenters");
    }

    /**
     * Extracts the ids of clusters flagged as being divided in the current bisecting step.
     * The decompiler-emitted synthetic bridge {@code flatMap(Object, Collector)} is removed; the
     * compiler generates it.
     */
    private static DataSet<Long> getDivisibleClusterIndices(DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> dataSet) {
        return dataSet.flatMap(new FlatMapFunction<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>, Long>() {
            private static final long serialVersionUID = -3378609606739051729L;

            @Override
            public void flatMap(Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster, Collector<Long> out) {
                LOG.info("getDivisibleS {}", cluster);
                if (cluster.f2.isDividing) {
                    out.collect(cluster.f0);
                }
            }
        }).name("getDivisibleClusterIndices");
    }

    /**
     * Gathers all clusters on subtask 0. At the first inner step of a bisecting step it selects the
     * leaf clusters to bisect, marks them as dividing, and emits two new child clusters for each,
     * with randomly perturbed initial centers. At other inner steps it forwards the clusters unchanged.
     *
     * @param dataSet current clusters with their summaries and iteration info
     * @param i       target number of leaf clusters (compared with leaf + split counts; presumably k)
     * @param i2      a leaf is only split when its size is strictly greater than this value
     * @param i3      seed of the random generator used to perturb initial child centers
     * @return the clusters, extended with new children when a bisection happens
     */
    private static DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> getOrSplitClusters(DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> dataSet, final int i, final int i2, final int i3) {
        return dataSet
            // Route every cluster to partition 0 so the split decision sees the whole tree.
            .partitionCustom(new Partitioner<Integer>() {
                private static final long serialVersionUID = -9210153686004045278L;

                @Override
                public int partition(Integer key, int numPartitions) {
                    return 0;
                }
            }, new KeySelector<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>, Integer>() {
                private static final long serialVersionUID = 1038545655756366007L;

                @Override
                public Integer getKey(Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster) {
                    return 0;
                }
            })
            .mapPartition(new RichMapPartitionFunction<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>, Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>>() {
                private static final long serialVersionUID = 673707294676457023L;
                private transient Random random;

                @Override
                public void open(Configuration configuration) {
                    // Only subtask 0 does any work; the generator is created once and then reused.
                    if (this.random == null && getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        this.random = new Random(i3);
                    }
                }

                @Override
                public void mapPartition(Iterable<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> values, Collector<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> out) {
                    if (getRuntimeContext().getIndexOfThisSubtask() > 0) {
                        return;
                    }
                    List<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> clusters = new ArrayList<>();
                    values.forEach(clusters::add);
                    if (clusters.isEmpty()) {
                        // Nothing to forward or split (get(0) below would otherwise throw).
                        return;
                    }
                    if (!clusters.get(0).f2.doBisectionInStep()) {
                        clusters.forEach(out::collect);
                        return;
                    }
                    Set<Long> toSplit = BisectingKMeansTrainBatchOp.findSplitableClusters(clusters, i, i2);
                    boolean shouldStopSplit = toSplit.size() + BisectingKMeansTrainBatchOp.getNumLeaf(clusters) >= i;
                    for (Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster : clusters) {
                        IterInfo info = cluster.f2;
                        assert !info.isDividing;
                        assert !info.isNew;
                        info.shouldStopSplit = shouldStopSplit;
                        if (!toSplit.contains(cluster.f0)) {
                            out.collect(cluster);
                            continue;
                        }
                        Tuple2<DenseVector, DenseVector> childCenters = BisectingKMeansTrainBatchOp.initialSplitCenter(cluster.f1.center, this.random);
                        BisectingKMeansModelData.ClusterSummary leftSummary = new BisectingKMeansModelData.ClusterSummary();
                        leftSummary.center = childCenters.f0;
                        BisectingKMeansModelData.ClusterSummary rightSummary = new BisectingKMeansModelData.ClusterSummary();
                        rightSummary.center = childCenters.f1;
                        info.isDividing = true;
                        out.collect(cluster);
                        // Each child gets its own IterInfo so later in-place updates cannot affect both.
                        out.collect(Tuple3.of(BisectingKMeansTrainBatchOp.leftChildIndex(cluster.f0), leftSummary, newChildIterInfo(info, shouldStopSplit)));
                        out.collect(Tuple3.of(BisectingKMeansTrainBatchOp.rightChildIndex(cluster.f0), rightSummary, newChildIterInfo(info, shouldStopSplit)));
                    }
                }

                /** Iteration info for a freshly split child: copies the parent's counters, marks it new. */
                private IterInfo newChildIterInfo(IterInfo parent, boolean shouldStopSplit) {
                    return new IterInfo(parent.maxIter, parent.bisectingStepNo, parent.innerIterStepNo, false, true, shouldStopSplit);
                }
            }).name("get_or_split_clusters");
    }

    /**
     * Produces two initial child centers by subtracting/adding a small random noise vector from/to
     * {@code denseVector}. Each noise component is drawn from
     * {@code [0, 1e-4 * ||denseVector||)}.
     *
     * @return {@code (center - noise, center + noise)}
     */
    public static Tuple2<DenseVector, DenseVector> initialSplitCenter(DenseVector denseVector, Random random) {
        final int dim = denseVector.size();
        final double level = 1.0E-4d * Math.sqrt(BLAS.dot(denseVector, denseVector));
        DenseVector noise = new DenseVector(dim);
        for (int d = 0; d < dim; d++) {
            noise.set(d, level * random.nextDouble());
        }
        DenseVector lower = denseVector.minus((Vector) noise);
        DenseVector upper = denseVector.plus((Vector) noise);
        return Tuple2.of(lower, upper);
    }

    /**
     * Chooses which leaf clusters to bisect. Candidates are leaves whose size is greater than 1 and
     * greater than {@code i2}. At most {@code i - numberOfLeaves} of them are chosen, highest cost
     * first.
     *
     * @param list all clusters
     * @param i    target number of leaves
     * @param i2   leaves of size {@code <= i2} are never split
     * @return ids of the clusters to split
     */
    public static Set<Long> findSplitableClusters(List<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> list, int i, int i2) {
        Set<Long> existingIds = new HashSet<>();
        for (Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster : list) {
            existingIds.add(cluster.f0);
        }
        LOG.info("existingClusterIds {}", JsonConverter.toJson(existingIds));

        int numLeaves = 0;
        List<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> candidates = new ArrayList<>();
        for (Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster : list) {
            if (!isLeaf(existingIds, cluster.f0)) {
                continue;
            }
            numLeaves++;
            if (cluster.f1.size > 1 && cluster.f1.size > i2) {
                candidates.add(cluster);
            }
        }

        // Stable sort, descending by cost.
        candidates.sort((a, b) -> Double.compare(b.f1.cost, a.f1.cost));
        int budget = Math.min(i - numLeaves, candidates.size());
        List<Long> chosen = new ArrayList<>();
        for (int idx = 0; idx < budget; idx++) {
            chosen.add(candidates.get(idx).f0);
        }
        LOG.info("toSplitClusterIds {}", JsonConverter.toJson(chosen));
        return new HashSet<>(chosen);
    }

    /** Counts the clusters in {@code list} whose children are both absent from {@code list}. */
    public static int getNumLeaf(List<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> list) {
        Set<Long> ids = new HashSet<>();
        for (Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster : list) {
            ids.add(cluster.f0);
        }
        int leaves = 0;
        for (Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> cluster : list) {
            if (isLeaf(ids, cluster.f0)) {
                leaves++;
            }
        }
        return leaves;
    }

    /** A cluster is a leaf when neither of its child indices is present in {@code set}. */
    private static boolean isLeaf(Set<Long> set, long j) {
        return !set.contains(Long.valueOf(leftChildIndex(j))) && !set.contains(Long.valueOf(rightChildIndex(j)));
    }

    /**
     * Re-assigns points of divisible clusters to child clusters (see {@link UpdateAssignment}).
     * Each point is tagged with its current subtask index, grouped by that index, sorted by point id,
     * and partitioned with {@code index % parallelism}. This keeps the per-subtask assignment state
     * aligned across supersteps.
     */
    private static DataSet<Tuple3<Long, Vector, Long>> updateAssignment(DataSet<Tuple3<Long, Vector, Long>> dataSet, DataSet<Long> dataSet2, DataSet<Tuple2<Long, DenseVector>> dataSet3, ContinuousDistance continuousDistance, DataSet<Tuple1<IterInfo>> dataSet4) {
        DataSet<Tuple4<Integer, Long, Vector, Long>> tagged = dataSet
            .map(new RichMapFunction<Tuple3<Long, Vector, Long>, Tuple4<Integer, Long, Vector, Long>>() {
                private static final long serialVersionUID = 6790217208184538373L;
                private transient int subtaskIndex;

                @Override
                public void open(Configuration configuration) {
                    this.subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
                }

                @Override
                public Tuple4<Integer, Long, Vector, Long> map(Tuple3<Long, Vector, Long> point) {
                    return Tuple4.of(Integer.valueOf(this.subtaskIndex), point.f0, point.f1, point.f2);
                }
            })
            .withForwardedFields("f0->f1;f1->f2;f2->f3")
            .name("append_partition_id");

        Partitioner<Integer> backToSourceSubtask = new Partitioner<Integer>() {
            private static final long serialVersionUID = -1731024359933958956L;

            @Override
            public int partition(Integer key, int numPartitions) {
                return key.intValue() % numPartitions;
            }
        };

        return tagged
            .groupBy(0)
            .sortGroup(1, Order.ASCENDING)
            .withPartitioner(backToSourceSubtask)
            .reduceGroup(new UpdateAssignment(continuousDistance))
            .withBroadcastSet(dataSet2, DIVISIBLE_INDICES)
            .withBroadcastSet(dataSet3, NEW_CLUSTER_CENTERS)
            .withBroadcastSet(dataSet4, ITER_INFO)
            .name("update_assignment");
    }

    /**
     * Returns {@code j} when {@code vector} is strictly closer to {@code denseVector} than to
     * {@code denseVector2}; otherwise (including ties) returns {@code j2}.
     */
    public static long getClosestNode(long j, DenseVector denseVector, long j2, DenseVector denseVector2, Vector vector, ContinuousDistance continuousDistance) {
        double toFirst = continuousDistance.calc(vector, denseVector);
        double toSecond = continuousDistance.calc(vector, denseVector2);
        if (toFirst < toSecond) {
            return j;
        }
        return j2;
    }

    /**
     * Computes a {@link BisectingKMeansModelData.ClusterSummary} for each cluster id. It aggregates
     * locally per partition, merges the partial aggregates per cluster, and then converts each
     * aggregate to a summary.
     *
     * @param dataSet      (clusterId, point) pairs
     * @param dataSet2     single-element broadcast holding the vector size
     * @param distanceType distance type the summaries are computed for
     */
    private static DataSet<Tuple2<Long, BisectingKMeansModelData.ClusterSummary>> summary(DataSet<Tuple2<Long, Vector>> dataSet, DataSet<Integer> dataSet2, final HasKMeansDistanceType.DistanceType distanceType) {
        return dataSet.mapPartition(new RichMapPartitionFunction<Tuple2<Long, Vector>, Tuple2<Long, ClusterSummaryAggregator>>() {
            private static final long serialVersionUID = -1690065918903876005L;

            @Override
            public void mapPartition(Iterable<Tuple2<Long, Vector>> iterable, Collector<Tuple2<Long, ClusterSummaryAggregator>> collector) {
                final int vectorSize = getRuntimeContext().<Integer>getBroadcastVariable(VECTOR_SIZE).get(0);
                Map<Long, ClusterSummaryAggregator> aggregators = new HashMap<>();
                for (Tuple2<Long, Vector> point : iterable) {
                    // computeIfAbsent allocates an aggregator (and its DenseVector) only once per
                    // cluster; getOrDefault(..., new ...) allocated one for every point.
                    aggregators
                        .computeIfAbsent(point.f0, id -> new ClusterSummaryAggregator(vectorSize, distanceType))
                        .add(point.f1);
                }
                aggregators.forEach((id, aggregator) -> collector.collect(Tuple2.of(id, aggregator)));
            }
        }).name("local_aggregate_cluster_summary").withBroadcastSet(dataSet2, VECTOR_SIZE).groupBy(0).reduce(new ReduceFunction<Tuple2<Long, ClusterSummaryAggregator>>() {
            private static final long serialVersionUID = -674773996335302008L;

            @Override
            public Tuple2<Long, ClusterSummaryAggregator> reduce(Tuple2<Long, ClusterSummaryAggregator> tuple2, Tuple2<Long, ClusterSummaryAggregator> tuple22) {
                tuple2.f1.merge(tuple22.f1);
                return tuple2;
            }
        }).name("global_aggregate_cluster_summary").map(new MapFunction<Tuple2<Long, ClusterSummaryAggregator>, Tuple2<Long, BisectingKMeansModelData.ClusterSummary>>() {
            private static final long serialVersionUID = 123946900492166436L;

            @Override
            public Tuple2<Long, BisectingKMeansModelData.ClusterSummary> map(Tuple2<Long, ClusterSummaryAggregator> tuple2) {
                return Tuple2.of(tuple2.f0, tuple2.f1.toClusterSummary());
            }
        }).withForwardedFields("f0").name("make_cluster_summary");
    }

    /**
     * Folds freshly aggregated cluster summaries into the iteration state.
     *
     * <p>{@code state} is left-outer-joined with {@code newSummaries} on field 0 (the cluster id).
     * For every state row, {@code IterInfo.updateIterInfo()} is invoked on the row's iteration info.
     * When a matching (id, summary) pair exists, the row's id and summary are replaced by it.
     * Otherwise the state row is returned unchanged.
     *
     * <p>A state row without a match is rejected through {@code AkPreconditions.checkState}
     * if its {@code IterInfo.isNew} flag is set ("Encounter an empty cluster"). That check runs
     * before the iteration info is updated.
     *
     * @param state        current iteration state as (cluster id, summary, iteration info)
     * @param newSummaries newly computed (cluster id, summary) pairs
     * @return the merged state; the join operator is named "update_model" in the plan
     */
    private static DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> updateClusterSummariesAndIterInfo(
            DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> state,
            DataSet<Tuple2<Long, BisectingKMeansModelData.ClusterSummary>> newSummaries) {
        RichJoinFunction<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>,
                Tuple2<Long, BisectingKMeansModelData.ClusterSummary>,
                Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> mergeFunction =
            new RichJoinFunction<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>,
                    Tuple2<Long, BisectingKMeansModelData.ClusterSummary>,
                    Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>>() {
                private static final long serialVersionUID = -2488476860707642376L;

                @Override
                public Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> join(
                        Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> current,
                        Tuple2<Long, BisectingKMeansModelData.ClusterSummary> fresh) {
                    IterInfo info = current.f2;
                    boolean unmatched = fresh == null;
                    if (unmatched) {
                        // A cluster created in this step must have received at least one sample.
                        AkPreconditions.checkState(!info.isNew, "Encounter an empty cluster: {}", current);
                    }
                    info.updateIterInfo();
                    return unmatched ? current : Tuple3.of(fresh.f0, fresh.f1, info);
                }
            };
        return state.leftOuterJoin(newSummaries)
            .where(0)
            .equalTo(0)
            .with(mergeFunction)
            .name("update_model");
    }

    /**
     * Builds the bisecting k-means training plan on the training data taken from
     * {@code checkAndGetFirst(batchOperatorArr)}.
     *
     * <p>Plan outline:
     * <ol>
     *   <li>Extract the vector column together with its summary statistics. The summary is mapped
     *       to the vector size, which is broadcast as "vectorSize". The job fails with
     *       {@code AkIllegalDataException} if the summary counts zero rows.</li>
     *   <li>Give every sample a unique id and resize sparse vectors to the broadcast vector size.
     *       Every sample is assigned the initial cluster id {@code 1L}.</li>
     *   <li>Run a bulk iteration over (cluster id, cluster summary, iteration info). Each step calls
     *       {@code getOrSplitClusters}, reassigns samples through {@code updateAssignment}, and
     *       recomputes summaries. The loop terminates once the iteration info reports both
     *       {@code atLastInnerIterStep()} and {@code atLastBisectionStep()}.</li>
     *   <li>Turn the final (cluster id, summary) pairs into model rows with {@code SaveModel},
     *       running at parallelism 1.</li>
     * </ol>
     *
     * @param batchOperatorArr input operators; the training data is {@code checkAndGetFirst(batchOperatorArr)}
     * @return this operator, whose output is set to the model rows
     */
    @Override
    public BisectingKMeansTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        HasKMeansDistanceType.DistanceType distanceType = getDistanceType();
        int k = getK();
        final int maxIter = getMaxIter();
        String vectorCol = getVectorCol();
        int minDivisibleClusterSize = getMinDivisibleClusterSize();
        FastDistance fastDistance = distanceType.getFastDistance();

        Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> summaryHelper =
            StatisticsHelper.summaryHelper(in, null, vectorCol);

        // Vector size of the training data; broadcast to downstream operators as "vectorSize".
        DataSet<Integer> dim = summaryHelper.f1.map(new MapFunction<BaseVectorSummary, Integer>() {
            private static final long serialVersionUID = 5358843841535961680L;

            @Override
            public Integer map(BaseVectorSummary baseVectorSummary) {
                AkPreconditions.checkArgument(baseVectorSummary.count() > 0,
                    new AkIllegalDataException("The train dataset is empty!"));
                return baseVectorSummary.vectorSize();
            }
        });

        // (sample id, vector, cluster id): every sample starts in cluster 1.
        DataSet<Tuple3<Long, Vector, Long>> initialAssignment = DataSetUtils.zipWithUniqueId(summaryHelper.f0)
            .map(new RichMapFunction<Tuple2<Long, Vector>, Tuple3<Long, Vector, Long>>() {
                private static final long serialVersionUID = -6036596630416015773L;
                private int vectorSize;

                @Override
                public void open(Configuration configuration) {
                    this.vectorSize = getRuntimeContext().<Integer>getBroadcastVariable("vectorSize").get(0);
                }

                @Override
                public Tuple3<Long, Vector, Long> map(Tuple2<Long, Vector> tuple2) {
                    // Force every sparse vector to the common vector size of the data set.
                    if (tuple2.f1 instanceof SparseVector) {
                        ((SparseVector) tuple2.f1).setSize(this.vectorSize);
                    }
                    return Tuple3.of(tuple2.f0, tuple2.f1, 1L);
                }
            })
            .withBroadcastSet(dim, "vectorSize");

        // Iteration state: (cluster id, cluster summary, iteration info).
        IterativeDataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> loop =
            summary(initialAssignment.project(2, 1), dim, distanceType)
                .map(new MapFunction<Tuple2<Long, BisectingKMeansModelData.ClusterSummary>,
                        Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>>() {
                    private static final long serialVersionUID = -3883958936263294331L;

                    @Override
                    public Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo> map(
                            Tuple2<Long, BisectingKMeansModelData.ClusterSummary> tuple2) {
                        return Tuple3.of(tuple2.f0, tuple2.f1, new IterInfo(maxIter));
                    }
                })
                .withForwardedFields("f0;f1")
                .iterate(Integer.MAX_VALUE);

        // A single IterInfo of the current step (first row of the projected field 2).
        DataSet<Tuple1<IterInfo>> iterInfo = loop.project(2).reduceGroup(new FirstReducer(1));

        DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> allClusters =
            getOrSplitClusters(loop, k, minDivisibleClusterSize, getRandomSeed().intValue());

        // Reassign samples against the current clusters and re-aggregate their summaries.
        DataSet<Tuple2<Long, BisectingKMeansModelData.ClusterSummary>> newClusterSummaries = summary(
            updateAssignment(initialAssignment, getDivisibleClusterIndices(allClusters),
                getNewClusterCenters(allClusters), fastDistance, iterInfo).project(2, 1),
            dim, distanceType);

        DataSet<Tuple3<Long, BisectingKMeansModelData.ClusterSummary, IterInfo>> updatedClusters =
            updateClusterSummariesAndIterInfo(allClusters, newClusterSummaries);

        // Termination criterion: a Flink bulk iteration stops once this data set is empty.
        DataSet<Integer> stopCriterion = iterInfo.flatMap(new FlatMapFunction<Tuple1<IterInfo>, Integer>() {
            private static final long serialVersionUID = -4258243788034193744L;

            @Override
            public void flatMap(Tuple1<IterInfo> tuple1, Collector<Integer> collector) {
                IterInfo info = tuple1.f0;
                if (!(info.atLastInnerIterStep() && info.atLastBisectionStep())) {
                    collector.collect(0);
                }
            }
        });

        DataSet<Tuple2<Long, BisectingKMeansModelData.ClusterSummary>> finalClusterSummaries =
            loop.closeWith(updatedClusters, stopCriterion).project(0, 1);

        DataSet<Row> modelRows = finalClusterSummaries
            .mapPartition(new SaveModel(distanceType, vectorCol, k))
            .withBroadcastSet(dim, "vectorSize")
            .setParallelism(1);

        setOutput(modelRows, new BisectingKMeansModelDataConverter().getModelSchema());
        return this;
    }

    // No hand-written linkFrom(BatchOperator[]) overload here. It would have the same erasure as
    // linkFrom(BatchOperator<?>...) above, which javac rejects. javac already emits the covariant
    // bridge method for that override, so array-based callers keep working unchanged.
}
