package com.alibaba.alink.operator.common.clustering.kmeans;

import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.utils.Functional;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceData;
import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.params.clustering.KMeansTrainParams;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.operators.FilterOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction;
import org.apache.flink.shaded.guava18.com.google.common.hash.Hashing;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansInitCentroids.class */
public class KMeansInitCentroids implements Serializable {
    /** Logger used by the nested functions (LocalKmeans logs its start and end through it). */
    private static final Logger LOG = LoggerFactory.getLogger(KMeansInitCentroids.class);
    /** Broadcast-variable name under which the current set of centers is published to CalWeight. */
    private static final String CENTER = "centers";
    /** Broadcast-variable name of the single Tuple1 holding the summed cost, read by FilterNewCenter. */
    private static final String SUM_COSTS = "sumCosts";
    /** Broadcast-variable name for the vector size (an Integer). */
    private static final String VECTOR_SIZE = "vectorSize";
    private static final long serialVersionUID = -8219073698493308787L;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansInitCentroids$CalWeight.class */
    public static class CalWeight extends RichMapFunction<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>, Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> {
        private static final long serialVersionUID = 4828540656733702022L;
        private transient List<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> centers;
        private FastDistance distance;

        CalWeight(FastDistance fastDistance) {
            this.distance = fastDistance;
        }

        public void open(Configuration configuration) throws Exception {
            this.centers = getRuntimeContext().getBroadcastVariable(KMeansInitCentroids.CENTER);
        }

        public Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> map(Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> tuple5) throws Exception {
            for (Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> tuple52 : this.centers) {
                if (((Long) tuple52.f0).equals(tuple5.f0)) {
                    tuple5.f4 = true;
                } else {
                    double d = this.distance.calc((FastDistanceData) tuple52.f1, (FastDistanceData) tuple5.f1).get(0, 0);
                    if (d < ((Double) tuple5.f3).doubleValue()) {
                        tuple5.f2 = tuple52.f0;
                        tuple5.f3 = Double.valueOf(d);
                    }
                }
            }
            return tuple5;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansInitCentroids$FilterCenter.class */
    public static class FilterCenter implements FilterFunction<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> {
        private static final long serialVersionUID = -8363415544217000362L;

        private FilterCenter() {
        }

        public boolean filter(Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> tuple5) throws Exception {
            return ((Boolean) tuple5.f4).booleanValue();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansInitCentroids$FilterData.class */
    public static class FilterData implements FilterFunction<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> {
        private static final long serialVersionUID = -7062845155572458129L;

        private FilterData() {
        }

        public boolean filter(Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> tuple5) throws Exception {
            return !((Boolean) tuple5.f4).booleanValue();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansInitCentroids$FilterNewCenter.class */
    public static class FilterNewCenter extends RichFilterFunction<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> {
        private static final long serialVersionUID = -6433641767140518820L;
        private transient double costThre;
        private transient Random random;
        private int k;
        private int seed;

        FilterNewCenter(int i, int i2) {
            this.k = i;
            this.seed = i2;
        }

        public void open(Configuration configuration) {
            this.random = new Random(this.seed);
            this.costThre = (2.0d * this.k) / ((Double) ((Tuple1) getRuntimeContext().getBroadcastVariable(KMeansInitCentroids.SUM_COSTS).get(0)).f0).doubleValue();
        }

        public boolean filter(Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> tuple5) {
            return this.random.nextDouble() < ((Double) tuple5.f3).doubleValue() * this.costThre;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansInitCentroids$LocalKmeans.class */
    public static class LocalKmeans extends RichMapPartitionFunction<Tuple2<Long, FastDistanceVectorData>, FastDistanceMatrixData> {
        private static final long serialVersionUID = 3014142447237244585L;
        private FastDistance distance;
        private int k;
        private transient int vectorSize;
        private int seed;

        LocalKmeans(int i, FastDistance fastDistance, int i2) {
            this.k = i;
            this.distance = fastDistance;
            this.seed = i2;
        }

        public void open(Configuration configuration) {
            KMeansInitCentroids.LOG.info("TaskId {} Local Kmeans begins!", Integer.valueOf(getRuntimeContext().getIndexOfThisSubtask()));
            this.vectorSize = ((Integer) getRuntimeContext().getBroadcastVariable("vectorSize").get(0)).intValue();
        }

        public void close() {
            KMeansInitCentroids.LOG.info("TaskId {} Local Kmeans ends!", Integer.valueOf(getRuntimeContext().getIndexOfThisSubtask()));
        }

        public void mapPartition(Iterable<Tuple2<Long, FastDistanceVectorData>> iterable, Collector<FastDistanceMatrixData> collector) throws Exception {
            ArrayList arrayList = new ArrayList();
            arrayList.getClass();
            iterable.forEach((v1) -> {
                r1.add(v1);
            });
            arrayList.sort(Comparator.comparingLong(tuple2 -> {
                return ((Long) tuple2.f0).longValue();
            }));
            ArrayList arrayList2 = new ArrayList();
            ArrayList arrayList3 = new ArrayList();
            arrayList.forEach(tuple22 -> {
                arrayList2.add(tuple22.f0);
                arrayList3.add(tuple22.f1);
            });
            if (arrayList3.size() <= this.k) {
                collector.collect(KMeansUtil.buildCentroidsMatrix(arrayList3, this.distance, this.vectorSize));
                return;
            }
            long[] jArr = new long[arrayList2.size()];
            for (int i = 0; i < jArr.length; i++) {
                jArr[i] = ((Long) arrayList2.get(i)).longValue();
            }
            collector.collect(LocalKmeansFunc.kmeans(this.k, jArr, (FastDistanceVectorData[]) arrayList3.toArray(new FastDistanceVectorData[0]), this.distance, this.vectorSize, this.seed));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansInitCentroids$TransformToCenter.class */
    public static class TransformToCenter implements MapFunction<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>, Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> {
        private static final long serialVersionUID = 5589065815045593976L;

        private TransformToCenter() {
        }

        public Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> map(Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> tuple5) throws Exception {
            tuple5.f2 = -1L;
            tuple5.f3 = Double.valueOf(Double.MAX_VALUE);
            tuple5.f4 = true;
            return tuple5;
        }
    }

    /**
     * Serializable holder for a {@code TreeMap<Long, T>}. It is used to carry the per-partition
     * selection of {@code selectTopK} between Flink operators. The map iterates in ascending key order.
     */
    public static class TreeMapT<T extends Serializable> implements Serializable {
        private static final long serialVersionUID = -2350624942257559150L;
        public TreeMap<Long, T> treeMap;

        public TreeMapT() {
            this.treeMap = new TreeMap<>();
        }
    }

    /**
     * Builds the initial centroid matrix according to {@code KMeansTrainParams.INIT_MODE}.
     *
     * @param dataSet      the samples to choose centroids from
     * @param fastDistance distance used to build the centroid matrix
     * @param params       training params (INIT_MODE, INIT_STEPS and K are read)
     * @param dataSet2     single-element data set carrying the vector size
     * @param i            seed for hashing and random sampling
     * @return a data set holding the initial centroid matrix
     * @throws AkIllegalArgumentException if the init mode is neither RANDOM nor K_MEANS_PARALLEL
     */
    public static DataSet<FastDistanceMatrixData> initKmeansCentroids(DataSet<FastDistanceVectorData> dataSet, FastDistance fastDistance, Params params, DataSet<Integer> dataSet2, int i) {
        final InitMode mode = params.get(KMeansTrainParams.INIT_MODE);
        final int initSteps = params.get(KMeansTrainParams.INIT_STEPS);
        final int k = params.get(KMeansTrainParams.K);
        switch (mode) {
            case RANDOM:
                return randomInit(dataSet, k, fastDistance, dataSet2, i);
            case K_MEANS_PARALLEL:
                return kMeansPlusPlusInit(dataSet, k, initSteps, fastDistance, dataSet2, i);
            default:
                throw new AkIllegalArgumentException("Unknown init mode: " + mode);
        }
    }

    /**
     * "Random" initialization. It hashes each vector's bytes, lets {@link #selectTopK} pick the
     * centers, and packs the picked vectors into one centroid matrix in a single task.
     *
     * @param dataSet      the samples
     * @param i            number of centroids requested (k)
     * @param fastDistance distance used to build the centroid matrix
     * @param dataSet2     single-element data set carrying the vector size
     * @param i2           seed of the murmur3 hash used by selectTopK
     * @return a data set holding the centroid matrix
     */
    private static DataSet<FastDistanceMatrixData> randomInit(DataSet<FastDistanceVectorData> dataSet, int i, final FastDistance fastDistance, DataSet<Integer> dataSet2, int i2) {
        DataSet<Tuple2<Long, FastDistanceVectorData>> picked = selectTopK(i, i2, dataSet,
            new Functional.SerializableFunction<FastDistanceVectorData, byte[]>() {
                private static final long serialVersionUID = 6092460932245165972L;

                @Override
                public byte[] apply(FastDistanceVectorData fastDistanceVectorData) {
                    return fastDistanceVectorData.getVector().toBytes();
                }
            });

        return picked
            .mapPartition(new RichMapPartitionFunction<Tuple2<Long, FastDistanceVectorData>, FastDistanceMatrixData>() {
                private static final long serialVersionUID = 2012759243672199273L;

                @Override
                public void mapPartition(Iterable<Tuple2<Long, FastDistanceVectorData>> iterable, Collector<FastDistanceMatrixData> collector) {
                    List<Integer> sizes = getRuntimeContext().getBroadcastVariable(VECTOR_SIZE);
                    int vectorSize = sizes.get(0);
                    List<FastDistanceVectorData> centroids = new ArrayList<>();
                    for (Tuple2<Long, FastDistanceVectorData> picked : iterable) {
                        centroids.add(picked.f1);
                    }
                    // The original referenced 'FastDistance.this', which does not compile in a static
                    // context; the captured parameter is what was meant.
                    collector.collect(KMeansUtil.buildCentroidsMatrix(centroids, fastDistance, vectorSize));
                }
            })
            .withBroadcastSet(dataSet2, VECTOR_SIZE)
            .setParallelism(1);
    }

    /**
     * k-means|| ("k-means parallel") initialization.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Every sample gets an id: the murmur3 hash of its {@code toString()}.</li>
     *   <li>The sample with the largest id becomes the first center.</li>
     *   <li>Over {@code i2 - 1} rounds, non-center samples are promoted to centers by
     *       FilterNewCenter.</li>
     *   <li>Each resulting center is weighted by the number of non-center samples nearest to it,
     *       plus one for itself.</li>
     *   <li>LocalKmeans reduces the weighted centers to at most {@code i} centroids.</li>
     * </ol>
     *
     * @param dataSet      the samples
     * @param i            number of centroids requested (k)
     * @param i2           init steps; {@code i2 - 1} sampling rounds are run (none when i2 <= 1)
     * @param fastDistance distance used for costs and for the final local k-means
     * @param dataSet2     single-element data set carrying the vector size
     * @param i3           seed for hashing and sampling
     * @return a data set holding the initial centroid matrix
     */
    private static DataSet<FastDistanceMatrixData> kMeansPlusPlusInit(DataSet<FastDistanceVectorData> dataSet, int i, int i2, FastDistance fastDistance, DataSet<Integer> dataSet2, int i3) {
        final HashFunction hashFunction = Hashing.murmur3_128(i3);

        // (id, vector, nearestCenterId = -1, distance = MAX_VALUE, isCenter = false)
        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> samples = dataSet
            .map(new MapFunction<FastDistanceVectorData, Tuple2<Long, FastDistanceVectorData>>() {
                private static final long serialVersionUID = 1539229008777267709L;

                @Override
                public Tuple2<Long, FastDistanceVectorData> map(FastDistanceVectorData value) {
                    return Tuple2.of(hashFunction.hashUnencodedChars(value.toString()).asLong(), value);
                }
            })
            .map(new MapFunction<Tuple2<Long, FastDistanceVectorData>, Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>>() {
                private static final long serialVersionUID = -8289894247468770813L;

                @Override
                public Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> map(Tuple2<Long, FastDistanceVectorData> value) {
                    return Tuple5.of(value.f0, value.f1, -1L, Double.MAX_VALUE, false);
                }
            })
            .withForwardedFields("f0;f1");

        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> firstCenter = samples
            .maxBy(0)
            .map(new TransformToCenter())
            .withForwardedFields("f0;f1");

        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> initialState = samples
            .map(new CalWeight(fastDistance))
            .withBroadcastSet(firstCenter, CENTER)
            // f4 is NOT forwarded: CalWeight may set the "is a center" flag.
            .withForwardedFields("f0;f1");

        // A Flink bulk iteration needs at least one superstep, so skip it when there is nothing to run.
        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> finalState = i2 > 1
            ? runSamplingRounds(initialState, i, i2 - 1, fastDistance, i3)
            : initialState;

        return weightCenters(finalState)
            .mapPartition(new LocalKmeans(i, fastDistance, i3))
            .withBroadcastSet(dataSet2, VECTOR_SIZE)
            .setParallelism(1);
    }

    /** Runs {@code rounds} k-means|| sampling rounds as a Flink bulk iteration (rounds must be >= 1). */
    private static DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> runSamplingRounds(
        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> initialState,
        int k, int rounds, FastDistance fastDistance, int seed) {
        IterativeDataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> loop = initialState.iterate(rounds);

        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> nonCenters = loop.filter(new FilterData());
        DataSet<Tuple1<Double>> sumCosts = nonCenters.<Tuple1<Double>>project(3).aggregate(Aggregations.SUM, 0);

        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> newCenters = nonCenters
            .partitionCustom(new Partitioner<Long>() {
                private static final long serialVersionUID = 8742959167492464159L;

                @Override
                public int partition(Long key, int numPartitions) {
                    // |key % n| equals |key| % n for every key except Long.MIN_VALUE, whose
                    // Math.abs is negative and would give an invalid partition index.
                    return (int) Math.abs(key % numPartitions);
                }
            }, 0)
            .sortPartition(0, Order.DESCENDING)
            .filter(new FilterNewCenter(k, seed))
            .withBroadcastSet(sumCosts, SUM_COSTS)
            .name("kmeans_||_pick")
            .map(new TransformToCenter())
            .withForwardedFields("f0;f1");

        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> nextState = nonCenters
            .map(new CalWeight(fastDistance))
            .withBroadcastSet(newCenters, CENTER)
            .withForwardedFields("f0;f1")
            .union(loop.filter(new FilterCenter()));

        return loop.closeWith(nextState);
    }

    /**
     * Pairs every center with its weight and returns tuples of (weight, vector). The weight is the
     * number of non-center samples whose nearest center it is, plus one for the center itself. The
     * self-count keeps centers that no other sample chose, as Spark's k-means|| does.
     */
    private static DataSet<Tuple2<Long, FastDistanceVectorData>> weightCenters(
        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> state) {
        DataSet<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>> centers = state.filter(new FilterCenter());

        DataSet<Tuple2<Long, Long>> sampleVotes = state
            .filter(new FilterData())
            .map(new MapFunction<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>, Tuple2<Long, Long>>() {
                private static final long serialVersionUID = -7230628651729304469L;

                @Override
                public Tuple2<Long, Long> map(Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> value) {
                    return Tuple2.of(value.f2, 1L);
                }
            })
            .withForwardedFields("f2->f0");

        DataSet<Tuple2<Long, Long>> selfVotes = centers
            .map(new MapFunction<Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean>, Tuple2<Long, Long>>() {
                private static final long serialVersionUID = 4127760927338431213L;

                @Override
                public Tuple2<Long, Long> map(Tuple5<Long, FastDistanceVectorData, Long, Double, Boolean> value) {
                    return Tuple2.of(value.f0, 1L);
                }
            })
            .withForwardedFields("f0");

        return sampleVotes
            .union(selfVotes)
            .groupBy(0)
            .aggregate(Aggregations.SUM, 1)
            .join(centers.<Tuple2<Long, FastDistanceVectorData>>project(0, 1))
            .where(0)
            .equalTo(0)
            .projectFirst(1)
            .<Tuple2<Long, FastDistanceVectorData>>projectSecond(1);
    }

    /**
     * Selects a pseudo-random subset of {@code dataSet} by hashing each element.
     *
     * <p>Each element is keyed by the murmur3 hash (seeded with {@code i2}) of the bytes returned
     * by {@code serializableFunction}. Every partition keeps a bounded {@code TreeMap} through
     * {@code KMeansUtil.updateQueue} with capacity {@code i}; presumably this keeps the {@code i}
     * smallest hashes — TODO confirm. The partial maps are merged by a reduce. The survivors are
     * then emitted as {@code (rank, element)}, with ranks 0, 1, 2, ... in ascending hash order.
     *
     * @param i                    maximum number of elements to keep
     * @param i2                   seed of the murmur3 hash
     * @param dataSet              elements to choose from
     * @param serializableFunction maps an element to the bytes that are hashed
     * @return the selected elements paired with their rank
     */
    public static <T> DataSet<Tuple2<Long, T>> selectTopK(final int i, int i2, DataSet<T> dataSet, final Functional.SerializableFunction<T, byte[]> serializableFunction) {
        final TypeInformation<T> elementType = dataSet.getType();
        final TupleTypeInfo<Tuple2<Long, T>> resultType = new TupleTypeInfo<>(Types.LONG, elementType);
        final HashFunction hashFunction = Hashing.murmur3_128(i2);

        return dataSet
            .map(new RichMapFunction<T, Tuple2<Long, T>>() {
                private static final long serialVersionUID = 6994623243686615646L;

                @Override
                public Tuple2<Long, T> map(T value) {
                    return Tuple2.of(hashFunction.hashBytes(serializableFunction.apply(value)).asLong(), value);
                }
            })
            .returns(resultType)
            .mapPartition(new MapPartitionFunction<Tuple2<Long, T>, TreeMapT>() {
                private static final long serialVersionUID = 5813015538018452614L;

                @Override
                public void mapPartition(Iterable<Tuple2<Long, T>> values, Collector<TreeMapT> out) {
                    TreeMapT queue = new TreeMapT();
                    long head = Long.MAX_VALUE;
                    for (Tuple2<Long, T> value : values) {
                        head = KMeansUtil.updateQueue(queue.treeMap, value.f0.longValue(), value.f1, i, head);
                    }
                    out.collect(queue);
                }
            })
            .returns(TreeMapT.class)
            .reduce(new ReduceFunction<TreeMapT>() {
                private static final long serialVersionUID = -3472592245281912784L;

                @Override
                @SuppressWarnings("unchecked")
                public TreeMapT reduce(TreeMapT left, TreeMapT right) {
                    if (right.treeMap.isEmpty()) {
                        return left;
                    }
                    if (left.treeMap.isEmpty()) {
                        return right;
                    }
                    // TreeMapT is used raw here (T is not bounded by Serializable), so view the
                    // incoming map with Object values to iterate its entries type-safely.
                    TreeMap<Long, Object> incoming = right.treeMap;
                    long head = (Long) left.treeMap.lastEntry().getKey();
                    for (Map.Entry<Long, Object> entry : incoming.entrySet()) {
                        head = KMeansUtil.updateQueue(left.treeMap, entry.getKey().longValue(), entry.getValue(), i, head);
                    }
                    return left;
                }
            })
            .returns(TreeMapT.class)
            .flatMap(new FlatMapFunction<TreeMapT, Tuple2<Long, T>>() {
                private static final long serialVersionUID = 3317982387795941044L;

                @Override
                @SuppressWarnings("unchecked")
                public void flatMap(TreeMapT selected, Collector<Tuple2<Long, T>> out) {
                    long rank = 0L;
                    // values() iterates in ascending key (hash) order, like the original entry walk.
                    for (Object element : selected.treeMap.values()) {
                        out.collect(Tuple2.of(rank++, (T) element));
                    }
                }
            })
            .returns(resultType);
    }
}
