package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.comqueue.IterTaskObjKeeper;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.common.Sample;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.clustering.GroupKMeansParams;
import com.alibaba.alink.params.shared.clustering.HasClusteringDistanceType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.NumberSequenceIterator;
import org.apache.flink.util.Preconditions;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "groupCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "idCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@NameCn("分组Kmeans")
@NameEn("Group Kmeans")
/**
 * Runs k-means independently inside every group defined by {@code groupCols}.
 *
 * <p>Each input row becomes a {@link Sample} holding its group key, its id and a dense vector
 * built from {@code featureCols}. Groups are spread over the parallel subtasks according to their
 * sizes (see {@link #getPartitionInfo}). Samples and per-group centroids are cached in
 * {@link IterTaskObjKeeper} across supersteps.
 *
 * <p>Superstep layout of the bulk iteration:
 * <ol>
 *   <li>Superstep 1: every subtask computes and caches the group-to-partition assignment.</li>
 *   <li>Superstep 2: samples are shuffled to their partitions and cached. Initial centroids (the
 *   first {@code k} samples seen on a group's first partition) are broadcast and cached, and the
 *   first assignment/update step runs.</li>
 *   <li>Later supersteps: assignment/update steps only.</li>
 * </ol>
 *
 * <p>The iteration ends after {@code maxIter} supersteps, or once no subtask reports a group whose
 * mean distance changed by more than {@code epsilon}.
 *
 * <p>Output columns: the group columns, the id column, and the prediction column holding the
 * cluster index as {@code LONG}.
 */
public final class GroupKMeansBatchOp extends BatchOperator<GroupKMeansBatchOp> implements GroupKMeansParams<GroupKMeansBatchOp> {

    /** Broadcast name under which the (group, size) pairs reach the rebalancing operator. */
    private static final String GROUP_SIZE_BROADCAST_KEY = "groupAndSizeKey";

    /**
     * Caches the initial centroids received in superstep 2 and initialises each group's loss.
     *
     * <p>Every other superstep is a no-op. This operator emits nothing; it is only consumed as a
     * broadcast set so that {@link ComputeUpdates} runs after the model has been cached.
     */
    public static class CacheInitModel extends RichMapPartitionFunction<Tuple3<Integer, String, double[][]>, Object> {
        private final long cacheModelHandler;
        private final long lossHandler;
        private final long partitionInfoHandle;

        public CacheInitModel(long cacheModelHandler, long partitionInfoHandle, long lossHandler) {
            this.cacheModelHandler = cacheModelHandler;
            this.partitionInfoHandle = partitionInfoHandle;
            this.lossHandler = lossHandler;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void mapPartition(Iterable<Tuple3<Integer, String, double[][]>> values, Collector<Object> out) throws Exception {
            if (getIterationRuntimeContext().getSuperstepNumber() != 2) {
                return;
            }
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            Map<String, int[]> partitionInfo = (Map<String, int[]>) IterTaskObjKeeper.get(this.partitionInfoHandle, taskId);
            Preconditions.checkNotNull(partitionInfo);
            List<String> groupNames = GroupKMeansBatchOp.getGroupNames(partitionInfo, taskId);

            Map<String, double[][]> centroidsByGroup = new HashMap<>();
            Map<String, Double> lossByGroup = new HashMap<>();
            for (Tuple3<Integer, String, double[][]> model : values) {
                centroidsByGroup.put(model.f1, model.f2);
                // Starting loss value for the first convergence check in UpdateModel.
                lossByGroup.put(model.f1, Criteria.INVALID_GAIN);
            }
            // Every group served by this subtask must have received exactly one initial model.
            if (centroidsByGroup.size() != groupNames.size()) {
                throw new RuntimeException("Illegal model size.");
            }
            IterTaskObjKeeper.put(this.cacheModelHandler, taskId, centroidsByGroup);
            IterTaskObjKeeper.put(this.lossHandler, taskId, lossByGroup);
        }
    }

    /**
     * Caches this subtask's samples and generates initial centroids in superstep 2.
     *
     * <p>Samples are bucketed per group and cached under {@code cacheDataHandle}. For every group
     * whose first assigned partition is this subtask, the first {@code min(numClusters, n)} cached
     * samples become the initial centroids. Those centroids are emitted once to every partition
     * serving the group.
     */
    public static class CacheSamplesAndGenInitModel extends RichMapPartitionFunction<Tuple2<Integer, Sample>, Tuple3<Integer, String, double[][]>> {
        private final long cacheDataHandle;
        private final long partitionInfoHandle;
        private final int numClusters;

        public CacheSamplesAndGenInitModel(long cacheDataHandle, long partitionInfoHandle, int numClusters) {
            this.cacheDataHandle = cacheDataHandle;
            this.partitionInfoHandle = partitionInfoHandle;
            this.numClusters = numClusters;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void mapPartition(Iterable<Tuple2<Integer, Sample>> values, Collector<Tuple3<Integer, String, double[][]>> out) throws Exception {
            if (getIterationRuntimeContext().getSuperstepNumber() != 2) {
                return;
            }
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            Map<String, int[]> partitionInfo = (Map<String, int[]>) IterTaskObjKeeper.get(this.partitionInfoHandle, taskId);
            Preconditions.checkNotNull(partitionInfo);
            List<String> groupNames = GroupKMeansBatchOp.getGroupNames(partitionInfo, taskId);

            Map<String, List<Sample>> buckets = new HashMap<>();
            for (String group : groupNames) {
                buckets.put(group, new ArrayList<>());
            }
            for (Tuple2<Integer, Sample> routed : values) {
                buckets.get(routed.f1.getGroupColNamesString()).add(routed.f1);
            }
            Map<String, Sample[]> cachedSamples = new HashMap<>();
            for (Map.Entry<String, List<Sample>> bucket : buckets.entrySet()) {
                cachedSamples.put(bucket.getKey(), bucket.getValue().toArray(new Sample[0]));
            }
            IterTaskObjKeeper.put(this.cacheDataHandle, taskId, cachedSamples);

            for (String group : groupNames) {
                int[] targets = partitionInfo.get(group);
                if (targets[0] != taskId) {
                    continue;
                }
                // The rebalancing step sends each upstream subtask's first sample of a group to
                // targets[0], so this array is non-empty for any group that has data.
                Sample[] groupSamples = cachedSamples.get(group);
                int numCenters = Math.min(this.numClusters, groupSamples.length);
                int dim = groupSamples[0].getVector().getData().length;
                double[][] centers = new double[numCenters][];
                for (int c = 0; c < numCenters; c++) {
                    centers[c] = Arrays.copyOf(groupSamples[c].getVector().getData(), dim);
                }
                for (int target : targets) {
                    out.collect(Tuple3.of(target, group, centers));
                }
            }
        }
    }

    /**
     * Assignment step of k-means, run in every superstep from 2 onward.
     *
     * <p>Each cached sample is assigned to its closest centroid, and its cluster id is updated in
     * place. For every group this subtask serves, a statistics matrix is emitted to every partition
     * of that group. The matrix has one row per centroid, laid out as
     * {@code [featureSum_0 .. featureSum_{d-1}, pointCount, distanceSum]}.
     */
    public static class ComputeUpdates extends RichMapPartitionFunction<Object, Tuple3<Integer, String, double[][]>> {
        private final long cacheDataHandle;
        private final long cacheModelHandle;
        private final ContinuousDistance distance;
        private final long partitionInfoHandle;

        public ComputeUpdates(long cacheDataHandle, long cacheModelHandle, long partitionInfoHandle, ContinuousDistance distance) {
            this.cacheDataHandle = cacheDataHandle;
            this.cacheModelHandle = cacheModelHandle;
            this.partitionInfoHandle = partitionInfoHandle;
            this.distance = distance;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void mapPartition(Iterable<Object> values, Collector<Tuple3<Integer, String, double[][]>> out) throws Exception {
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                return;
            }
            Map<String, double[][]> centroidsByGroup = (Map<String, double[][]>) IterTaskObjKeeper.get(this.cacheModelHandle, taskId);
            Map<String, Sample[]> samplesByGroup = (Map<String, Sample[]>) IterTaskObjKeeper.get(this.cacheDataHandle, taskId);
            Map<String, int[]> partitionInfo = (Map<String, int[]>) IterTaskObjKeeper.get(this.partitionInfoHandle, taskId);
            Preconditions.checkNotNull(samplesByGroup);
            Preconditions.checkNotNull(centroidsByGroup);
            Preconditions.checkNotNull(partitionInfo);

            for (Map.Entry<String, Sample[]> entry : samplesByGroup.entrySet()) {
                String group = entry.getKey();
                double[][] centers = centroidsByGroup.get(group);
                int dim = centers[0].length;
                double[][] stats = new double[centers.length][dim + 2];
                for (Sample sample : entry.getValue()) {
                    double[] point = sample.getVector().getData();
                    Tuple2<Integer, Double> closest = GroupKMeansBatchOp.findClosestCluster(point, centers, this.distance);
                    int cluster = closest.f0;
                    sample.setClusterId(cluster);
                    double[] row = stats[cluster];
                    for (int i = 0; i < point.length; i++) {
                        row[i] += point[i];
                    }
                    row[dim] += 1.0d;
                    row[dim + 1] += closest.f1;
                }
                for (int target : partitionInfo.get(group)) {
                    out.collect(Tuple3.of(target, group, stats));
                }
            }
        }
    }

    /** Emits a {@code (groupKey, sampleCount)} pair for each group. */
    public static class ComputingGroupSizes implements GroupReduceFunction<Sample, Tuple2<String, Long>> {
        private ComputingGroupSizes() {
        }

        @Override
        public void reduce(Iterable<Sample> values, Collector<Tuple2<String, Long>> out) throws Exception {
            String group = null;
            long count = 0L;
            for (Sample sample : values) {
                if (count == 0L) {
                    group = sample.getGroupColNamesString();
                }
                count++;
            }
            out.collect(Tuple2.of(group, count));
        }
    }

    /** Keys a {@link Sample} by its concatenated group-column string. */
    public static class GroupNameKeySelector implements KeySelector<Sample, String> {
        private GroupNameKeySelector() {
        }

        @Override
        public String getKey(Sample sample) throws Exception {
            return sample.getGroupColNamesString();
        }
    }

    /** Routes a record to the partition whose index is the record's integer key, modulo the partition count. */
    public static class HashPartitioner implements Partitioner<Integer> {
        private HashPartitioner() {
        }

        @Override
        public int partition(Integer key, int numPartitions) {
            return key % numPartitions;
        }
    }

    /** Replaces each {@code Long} with a fresh placeholder {@code Object}; used to seed the iteration. */
    public static class MapLongToObject implements MapFunction<Long, Object> {
        private MapLongToObject() {
        }

        @Override
        public Object map(Long value) throws Exception {
            return new Object();
        }
    }

    /**
     * Converts an input row into a {@link Sample}.
     *
     * <p>Group values are stringified, feature values are read as {@link Number}, and the id is
     * stringified. A null in any of these columns is rejected.
     */
    public static class MapRowToSample implements MapFunction<Row, Sample> {
        private final int[] groupNameIndices;
        private final int idColIndex;
        private final int[] featureColIndices;

        public MapRowToSample(int[] groupNameIndices, int idColIndex, int[] featureColIndices) {
            this.groupNameIndices = groupNameIndices;
            this.idColIndex = idColIndex;
            this.featureColIndices = featureColIndices;
        }

        @Override
        public Sample map(Row row) throws Exception {
            String[] groupValues = new String[this.groupNameIndices.length];
            for (int i = 0; i < groupValues.length; i++) {
                Object value = row.getField(this.groupNameIndices[i]);
                Preconditions.checkNotNull(value, "There is NULL value in group col!");
                groupValues[i] = value.toString();
            }
            double[] features = new double[this.featureColIndices.length];
            for (int i = 0; i < features.length; i++) {
                Object value = row.getField(this.featureColIndices[i]);
                Preconditions.checkNotNull(value, "There is NULL value in feature col!");
                features[i] = ((Number) value).doubleValue();
            }
            Object id = row.getField(this.idColIndex);
            Preconditions.checkNotNull(id, "There is NULL value in id col!");
            return new Sample(id.toString(), new DenseVector(features), -1L, groupValues);
        }
    }

    /**
     * Converts a clustered {@link Sample} back into an output row.
     *
     * <p>The row holds the group values, then the id, then the cluster id. Group values and the id
     * are cast back to their original column types.
     */
    public static class MapSampleToRow implements MapFunction<Sample, Row> {
        private final int groupColNamesSize;
        private final TypeInformation<?>[] outputTypes;

        public MapSampleToRow(int groupColNamesSize, TypeInformation<?>[] outputTypes) {
            this.groupColNamesSize = groupColNamesSize;
            this.outputTypes = outputTypes;
        }

        @Override
        public Row map(Sample sample) throws Exception {
            Row row = new Row(this.groupColNamesSize + 2);
            for (int i = 0; i < this.groupColNamesSize; i++) {
                row.setField(i, EvaluationUtil.castTo(sample.getGroupColNames()[i], this.outputTypes[i]));
            }
            row.setField(this.groupColNamesSize, EvaluationUtil.castTo(sample.getSampleId(), this.outputTypes[this.groupColNamesSize]));
            row.setField(this.groupColNamesSize + 1, Long.valueOf(sample.getClusterId()));
            return row;
        }
    }

    /**
     * Emits the cached, clustered samples after the iteration and releases the cached state.
     *
     * <p>Each subtask takes (and removes) the first sample cache it finds among all subtask slots.
     * It then removes the model, partition-info and loss caches for every slot.
     */
    public static class OutputDataSamples extends RichMapPartitionFunction<Object, Sample> {
        private final long cacheDataHandle;
        private final long partitionInfoHandle;
        private final long cacheModelHandle;
        private final long lossHandle;

        public OutputDataSamples(long cacheDataHandle, long partitionInfoHandle, long cacheModelHandle, long lossHandle) {
            this.cacheDataHandle = cacheDataHandle;
            this.partitionInfoHandle = partitionInfoHandle;
            this.cacheModelHandle = cacheModelHandle;
            this.lossHandle = lossHandle;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void mapPartition(Iterable<Object> values, Collector<Sample> out) throws Exception {
            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
            Map<String, Sample[]> samplesByGroup = null;
            for (int slot = 0; slot < parallelism && samplesByGroup == null; slot++) {
                samplesByGroup = (Map<String, Sample[]>) IterTaskObjKeeper.containsAndRemoves(this.cacheDataHandle, slot);
            }
            Preconditions.checkNotNull(samplesByGroup);
            for (Sample[] groupSamples : samplesByGroup.values()) {
                for (Sample sample : groupSamples) {
                    out.collect(sample);
                }
            }
            for (int slot = 0; slot < parallelism; slot++) {
                IterTaskObjKeeper.remove(this.cacheModelHandle, slot);
                IterTaskObjKeeper.remove(this.partitionInfoHandle, slot);
                IterTaskObjKeeper.remove(this.lossHandle, slot);
            }
        }
    }

    /**
     * Plans the group placement (superstep 1) and routes samples to their partitions (superstep 2).
     *
     * <p>In superstep 1 the broadcast group sizes are turned into a group-to-partitions map, which
     * is cached. In superstep 2 each sample is tagged with a target partition, cycling round-robin
     * (per upstream subtask) through the partitions assigned to its group. The cycle always starts
     * at the group's first partition.
     */
    public static class RebalanceDataAndCachePartitionInfo extends RichMapPartitionFunction<Sample, Tuple2<Integer, Sample>> {
        private final String broadcastGroupSizeKey;
        private final long partitionInfoHandle;

        public RebalanceDataAndCachePartitionInfo(String broadcastGroupSizeKey, long partitionInfoHandle) {
            this.broadcastGroupSizeKey = broadcastGroupSizeKey;
            this.partitionInfoHandle = partitionInfoHandle;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void mapPartition(Iterable<Sample> values, Collector<Tuple2<Integer, Sample>> out) throws Exception {
            int superstep = getIterationRuntimeContext().getSuperstepNumber();
            if (superstep == 1) {
                List<Tuple2<String, Long>> groupSizes = getRuntimeContext().getBroadcastVariable(this.broadcastGroupSizeKey);
                Map<String, Long> sizeByGroup = new HashMap<>(groupSizes.size());
                for (Tuple2<String, Long> groupSize : groupSizes) {
                    sizeByGroup.put(groupSize.f0, groupSize.f1);
                }
                IterTaskObjKeeper.put(this.partitionInfoHandle, getRuntimeContext().getIndexOfThisSubtask(),
                    GroupKMeansBatchOp.getPartitionInfo(sizeByGroup, getRuntimeContext().getNumberOfParallelSubtasks()));
                return;
            }
            if (superstep == 2) {
                Map<String, int[]> partitionInfo = (Map<String, int[]>) IterTaskObjKeeper.get(this.partitionInfoHandle, getRuntimeContext().getIndexOfThisSubtask());
                Preconditions.checkNotNull(partitionInfo);
                Map<String, Integer> nextSlot = new HashMap<>(partitionInfo.size());
                for (Sample sample : values) {
                    String group = sample.getGroupColNamesString();
                    int[] targets = partitionInfo.get(group);
                    int slot = nextSlot.getOrDefault(group, 0);
                    nextSlot.put(group, (slot + 1) % targets.length);
                    out.collect(Tuple2.of(targets[slot], sample));
                }
            }
        }
    }

    /**
     * Update step of k-means, plus the convergence check.
     *
     * <p>In superstep 1 it emits a single placeholder so that the iteration continues. Later it sums
     * the statistics matrices received per group. Each centroid becomes the mean of its assigned
     * points; a centroid that received no points keeps its previous position instead of becoming
     * NaN. The group's loss (mean distance) is recorded. A placeholder is emitted (meaning "not
     * converged") if any group's loss changed by more than {@code tol}.
     */
    public static class UpdateModel extends RichMapPartitionFunction<Tuple3<Integer, String, double[][]>, Object> {
        private final long cacheModelHandle;
        private final long lossHandle;
        private final double tol;

        public UpdateModel(long cacheModelHandle, long lossHandle, double tol) {
            this.cacheModelHandle = cacheModelHandle;
            this.lossHandle = lossHandle;
            this.tol = tol;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void mapPartition(Iterable<Tuple3<Integer, String, double[][]>> values, Collector<Object> out) throws Exception {
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                out.collect(new Object());
                return;
            }
            Map<String, double[][]> statsByGroup = new HashMap<>();
            for (Tuple3<Integer, String, double[][]> partial : values) {
                double[][] acc = statsByGroup.get(partial.f1);
                if (acc == null) {
                    statsByGroup.put(partial.f1, partial.f2);
                    continue;
                }
                for (int c = 0; c < acc.length; c++) {
                    for (int j = 0; j < acc[c].length; j++) {
                        acc[c][j] += partial.f2[c][j];
                    }
                }
            }

            Map<String, double[][]> centroidsByGroup = (Map<String, double[][]>) IterTaskObjKeeper.get(this.cacheModelHandle, taskId);
            Map<String, Double> lossByGroup = (Map<String, Double>) IterTaskObjKeeper.get(this.lossHandle, taskId);
            Preconditions.checkNotNull(lossByGroup);
            Preconditions.checkNotNull(centroidsByGroup);

            boolean converged = true;
            for (Map.Entry<String, double[][]> entry : statsByGroup.entrySet()) {
                String group = entry.getKey();
                double[][] stats = entry.getValue();
                double[][] centers = centroidsByGroup.get(group);
                long totalCount = 0L;
                double totalDistance = 0.0d;
                for (int c = 0; c < stats.length; c++) {
                    double[] row = stats[c];
                    int dim = row.length - 2;
                    double count = row[dim];
                    totalCount = (long) (totalCount + count);
                    totalDistance += row[dim + 1];
                    // An empty cluster would otherwise get 0/0 = NaN coordinates and could never
                    // be selected again; keep its previous centroid instead.
                    if (count > 0) {
                        for (int j = 0; j < dim; j++) {
                            centers[c][j] = row[j] / count;
                        }
                    }
                }
                double previousLoss = lossByGroup.get(group);
                double loss = totalDistance / totalCount;
                lossByGroup.put(group, loss);
                if (Math.abs(loss - previousLoss) > this.tol) {
                    converged = false;
                }
            }
            if (!converged) {
                out.collect(new Object());
            }
        }
    }

    public GroupKMeansBatchOp() {
        super(null);
    }

    public GroupKMeansBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the group k-means dataflow on top of the first input operator.
     *
     * @param inputs input operators; only the first is used
     * @return this operator, with its output set to (group cols, id col, prediction col)
     * @throws RuntimeException if featureCols is missing, the id column is missing, or the
     *     feature, group and id columns overlap
     */
    @Override
    public GroupKMeansBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String[] featureCols = getFeatureCols();
        int k = getK();
        double epsilon = getEpsilon();
        int maxIter = getMaxIter();
        HasClusteringDistanceType.DistanceType distanceType = getDistanceType();
        String[] groupCols = getGroupCols();
        String idCol = getIdCol();
        FastDistance distance = distanceType.getFastDistance();
        validateInputColumns(featureCols, groupCols, idCol);

        TableSchema inputSchema = in.getSchema();
        String[] outputColNames = ArrayUtils.addAll(groupCols, new String[] {idCol, getPredictionCol()});
        TypeInformation<?>[] outputColTypes = ArrayUtils.addAll(
            TableUtil.findColTypesWithAssertAndHint(inputSchema, groupCols),
            new TypeInformation<?>[] {TableUtil.findColTypeWithAssertAndHint(inputSchema, idCol), Types.LONG});

        String[] inputColNames = in.getColNames();
        DataSet<Sample> samples = in.getDataSet().map(new MapRowToSample(
            TableUtil.findColIndices(inputColNames, groupCols),
            TableUtil.findColIndex(inputColNames, idCol),
            TableUtil.findColIndices(inputColNames, featureCols)));
        DataSet<Tuple2<String, Long>> groupSizes = samples
            .groupBy(new GroupNameKeySelector())
            .reduceGroup(new ComputingGroupSizes());

        long partitionInfoHandle = IterTaskObjKeeper.getNewHandle();
        long dataHandle = IterTaskObjKeeper.getNewHandle();
        long modelHandle = IterTaskObjKeeper.getNewHandle();
        long lossHandle = IterTaskObjKeeper.getNewHandle();

        // The loop carries only placeholder objects; all real state lives in IterTaskObjKeeper.
        IterativeDataSet<Object> loop = MLEnvironmentFactory.get(getMLEnvironmentId())
            .getExecutionEnvironment()
            .fromParallelCollection(new NumberSequenceIterator(1L, 2L), BasicTypeInfo.LONG_TYPE_INFO)
            .map(new MapLongToObject())
            .iterate(maxIter);

        DataSet<Tuple2<Integer, Sample>> routedSamples = samples
            .mapPartition(new RebalanceDataAndCachePartitionInfo(GROUP_SIZE_BROADCAST_KEY, partitionInfoHandle))
            .withBroadcastSet(loop, "loopStart")
            .withBroadcastSet(groupSizes, GROUP_SIZE_BROADCAST_KEY)
            .name("rebalanceData");
        DataSet<Tuple3<Integer, String, double[][]>> initModels = routedSamples
            .partitionCustom(new HashPartitioner(), 0)
            .mapPartition(new CacheSamplesAndGenInitModel(dataHandle, partitionInfoHandle, k))
            .name("cacheTrainingDataAndGetInitModel");
        DataSet<Object> modelCached = initModels
            .partitionCustom(new HashPartitioner(), 0)
            .mapPartition(new CacheInitModel(modelHandle, partitionInfoHandle, lossHandle))
            .name("cacheInitModel");
        DataSet<Tuple3<Integer, String, double[][]>> updates = loop
            .mapPartition(new ComputeUpdates(dataHandle, modelHandle, partitionInfoHandle, distance))
            .withBroadcastSet(modelCached, "cacheInitModel")
            .name("computeUpdates");
        DataSet<Object> notConverged = updates
            .partitionCustom(new HashPartitioner(), 0)
            .mapPartition(new UpdateModel(modelHandle, lossHandle, epsilon))
            .name("updateModel");

        DataSet<Object> loopEnd = loop.closeWith(notConverged, notConverged);
        DataSet<Row> output = loopEnd
            .mapPartition(new OutputDataSamples(dataHandle, partitionInfoHandle, modelHandle, lossHandle))
            .withBroadcastSet(loopEnd, "iterationEnd")
            .name("outputDataSamples")
            .map(new MapSampleToRow(groupCols.length, outputColTypes));
        setOutput(output, new TableSchema(outputColNames, outputColTypes));
        return this;
    }

    /** Rejects missing feature/id columns and any overlap between feature, group and id columns. */
    private static void validateInputColumns(String[] featureCols, String[] groupCols, String idCol) {
        if (featureCols == null || featureCols.length == 0) {
            throw new RuntimeException("featureColNames should be set !");
        }
        for (String groupCol : groupCols) {
            if (TableUtil.findColIndex(featureCols, groupCol) >= 0) {
                throw new RuntimeException("groupColNames should NOT be included in featureColNames!");
            }
        }
        if (idCol == null || idCol.isEmpty()) {
            throw new RuntimeException("idCol column should be set!");
        }
        if (TableUtil.findColIndex(featureCols, idCol) >= 0) {
            throw new RuntimeException("idCol column should NOT be included in featureColNames !");
        }
        if (TableUtil.findColIndex(groupCols, idCol) >= 0) {
            throw new RuntimeException("idCol column should NOT be included in groupColNames !");
        }
    }

    /**
     * Finds the centroid closest to {@code point}. Ties go to the lowest index.
     *
     * @param point feature vector
     * @param centers centroids, one per row
     * @param distance distance measure
     * @return (index of the closest centroid, distance to it)
     * @throws IllegalStateException if no distance is smaller than +infinity (for example, NaN input)
     */
    public static Tuple2<Integer, Double> findClosestCluster(double[] point, double[][] centers, ContinuousDistance distance) {
        double bestDistance = Double.POSITIVE_INFINITY;
        int bestIndex = -1;
        for (int c = 0; c < centers.length; c++) {
            double d = distance.calc(point, centers[c]);
            if (d < bestDistance) {
                bestDistance = d;
                bestIndex = c;
            }
        }
        if (bestIndex < 0) {
            throw new IllegalStateException(
                "Failed to find the closest cluster; the input may contain NaN or infinite values.");
        }
        return Tuple2.of(bestIndex, bestDistance);
    }

    /** Returns the groups whose partition list contains {@code taskId}, in map iteration order. */
    public static List<String> getGroupNames(Map<String, int[]> partitionInfo, int taskId) {
        List<String> groupNames = new ArrayList<>();
        for (Map.Entry<String, int[]> entry : partitionInfo.entrySet()) {
            for (int partition : entry.getValue()) {
                if (partition == taskId) {
                    groupNames.add(entry.getKey());
                    break;
                }
            }
        }
        return groupNames;
    }

    /**
     * Assigns each group to one or more partitions, balancing the estimated load greedily.
     *
     * <p>Groups are visited from largest to smallest (ties by name). A group of size {@code s} is
     * split over {@code max(1, s / (total / numPartitions + 1))} partitions. Each split is taken
     * from the least-loaded partition (ties by lowest index) and adds {@code s / splits} to that
     * partition's load. The partition indices of each group are returned sorted ascending.
     *
     * @param groupSizes number of samples per group key
     * @param numPartitions number of parallel partitions
     * @return group key to sorted array of assigned partition indices
     */
    @VisibleForTesting
    static Map<String, int[]> getPartitionInfo(Map<String, Long> groupSizes, int numPartitions) {
        PriorityQueue<Tuple2<Integer, Long>> loads = new PriorityQueue<Tuple2<Integer, Long>>((a, b) -> {
            int byLoad = Long.compare(a.f1, b.f1);
            return byLoad != 0 ? byLoad : Integer.compare(a.f0, b.f0);
        });
        for (int p = 0; p < numPartitions; p++) {
            loads.add(Tuple2.of(p, 0L));
        }

        List<Tuple2<String, Long>> groups = new ArrayList<>(groupSizes.size());
        for (Map.Entry<String, Long> entry : groupSizes.entrySet()) {
            groups.add(Tuple2.of(entry.getKey(), entry.getValue()));
        }
        groups.sort((a, b) -> {
            int bySizeDesc = Long.compare(b.f1, a.f1);
            return bySizeDesc != 0 ? bySizeDesc : a.f0.compareTo(b.f0);
        });

        Map<String, List<Integer>> assigned = new HashMap<>(groups.size());
        long total = 0L;
        for (Tuple2<String, Long> group : groups) {
            assigned.put(group.f0, new ArrayList<>());
            total += group.f1;
        }

        long capacityPerPartition = total / numPartitions + 1;
        for (Tuple2<String, Long> group : groups) {
            long size = group.f1;
            long splits = size / capacityPerPartition;
            if (splits == 0) {
                splits = 1;
            }
            long share = size / splits;
            for (int s = 0; s < splits; s++) {
                Tuple2<Integer, Long> leastLoaded = loads.remove();
                assigned.get(group.f0).add(leastLoaded.f0);
                loads.add(Tuple2.of(leastLoaded.f0, share + leastLoaded.f1));
            }
        }

        Map<String, int[]> partitionInfo = new HashMap<>();
        for (Map.Entry<String, List<Integer>> entry : assigned.entrySet()) {
            int[] partitions = entry.getValue().stream().mapToInt(Integer::intValue).toArray();
            Arrays.sort(partitions);
            partitionInfo.put(entry.getKey(), partitions);
        }
        return partitionInfo;
    }
}
