package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborPredictBatchOp;
import com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborTrainBatchOp;
import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.clustering.dbscan.Dbscan;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanModelDataConverter;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanModelInfoBatchOp;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanModelTrainData;
import com.alibaba.alink.operator.common.clustering.dbscan.LocalCluster;
import com.alibaba.alink.operator.common.clustering.dbscan.Type;
import com.alibaba.alink.operator.common.similarity.NearestNeighborsMapper;
import com.alibaba.alink.params.clustering.DbscanParams;
import com.alibaba.alink.params.shared.clustering.HasClusteringDistanceType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.base.JoinOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/**
 * DBSCAN clustering for batch data.
 *
 * <p>Outline of {@link #linkFrom}:
 * <ol>
 *   <li>Every input row gets a dense int id ({@code alink_unique_id}) from its zipWithIndex position.</li>
 *   <li>A radius search with {@code epsilon} (VectorNearestNeighbor train + predict on the same data) yields the
 *       neighbor ids of every row.</li>
 *   <li>Rows with at least {@code minPoints} neighbor ids are CORE; all other rows start as NOISE.</li>
 *   <li>Each partition unions the neighbor lists of its CORE rows into a key -&gt; cluster-id map
 *       ({@link ReduceLocal}).</li>
 *   <li>A bulk iteration merges those local maps into one global map until no partition changes it
 *       ({@link LocalMerge}).</li>
 *   <li>Cluster ids are renumbered to consecutive values starting at 0 ({@link AssignContinuousClusterId}).
 *       NOISE rows that appear in some CORE row's neighbor list are re-labelled LINKED. Rows that appear in no
 *       cluster get {@link Dbscan#NOISE} ({@link AssignAllClusterId}).</li>
 * </ol>
 *
 * <p>Main output columns: (idCol, {@link DbscanConstant#TYPE}, predictionCol).
 * Side output 0: the model built from the CORE rows' vectors and their cluster ids.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT), @PortSpec(value = PortType.MODEL, desc = PortDesc.DBSCAN_MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "vectorCol", portIndices = {0}, allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "idCol", portIndices = {0})})
@ReservedColsWithSecondInputSpec
@NameCn("DBSCAN")
@NameEn("DBSCAN")
public final class DbscanBatchOp extends BatchOperator<DbscanBatchOp> implements DbscanParams<DbscanBatchOp>, WithModelInfoBatchOp<DbscanModelInfoBatchOp.DbscanModelInfo, DbscanBatchOp, DbscanModelInfoBatchOp> {
    private static final long serialVersionUID = 2680984884814078788L;

    /** Internal column holding the dense int id assigned to each input row. */
    private static final String UNIQUE_ID_COL = "alink_unique_id";

    /**
     * Internal vector column.
     * After the neighbor search it holds each row's serialized neighbor list.
     */
    private static final String VECTOR_COL = "vector";

    /** Name of the broadcast set carrying the merged global key -> cluster-id map into {@link LocalMerge}. */
    private static final String GLOBAL_BROADCAST = "global";

    public DbscanBatchOp() {
        super(new Params());
    }

    public DbscanBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the DBSCAN dataflow on the first input.
     *
     * @param inputs the operator inputs; only the first one is used
     * @return this operator, with the labelled rows as main output and the model as side output 0
     */
    @Override
    public DbscanBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String idCol = getIdCol();
        String vectorCol = getVectorCol();

        // Step 1: (dense id, original id value, vector) per row. The zipWithIndex position is narrowed to int.
        DataSet<Tuple3<Integer, Object, Vector>> indexed = DataSetUtils
            .zipWithIndex(in.select(new String[] {idCol, vectorCol}).getDataSet())
            .map(new RichMapFunction<Tuple2<Long, Row>, Tuple3<Integer, Object, Vector>>() {
                private static final long serialVersionUID = -4516567863938069544L;

                @Override
                public Tuple3<Integer, Object, Vector> map(Tuple2<Long, Row> value) {
                    return Tuple3.of(value.f0.intValue(), value.f1.getField(0),
                        VectorUtil.getVector(value.f1.getField(1)));
                }
            });

        DataSetWrapperBatchOp points = new DataSetWrapperBatchOp(
            indexed.map(new MapFunction<Tuple3<Integer, Object, Vector>, Row>() {
                private static final long serialVersionUID = 672726382584730805L;

                @Override
                public Row map(Tuple3<Integer, Object, Vector> value) {
                    return Row.of(value.f0, value.f2);
                }
            }),
            new String[] {UNIQUE_ID_COL, VECTOR_COL},
            new TypeInformation[] {AlinkTypes.INT, AlinkTypes.VECTOR});

        // Step 2: epsilon-radius neighbor search of the data against itself.
        BatchOperator<?> neighborModel = new VectorNearestNeighborTrainBatchOp()
            .setIdCol(UNIQUE_ID_COL)
            .setSelectedCol(VECTOR_COL)
            .setMetric(getDistanceType().name())
            .linkFrom(points);

        // Step 3: (id, CORE/NOISE, sorted neighbor ids) per row.
        DataSet<Tuple3<Integer, Type, int[]>> pointTypes = new VectorNearestNeighborPredictBatchOp()
            .setSelectedCol(VECTOR_COL)
            .setRadius(getEpsilon())
            .setReservedCols(UNIQUE_ID_COL)
            .linkFrom(neighborModel, points)
            .select(new String[] {UNIQUE_ID_COL, VECTOR_COL})
            .getDataSet()
            .mapPartition(new GetCorePoints(getMinPoints()));

        // Step 4: one local key -> cluster-id map per partition.
        DataSet<Tuple2<Integer, LocalCluster>> localClusters = pointTypes.mapPartition(new ReduceLocal());

        // Step 5: iterate until no partition changes the global map.
        IterativeDataSet<Tuple2<Integer, LocalCluster>> loop = localClusters.iterate(Integer.MAX_VALUE);

        // The global map is the per-key maximum cluster id over all partitions of the partial solution.
        DataSet<Tuple2<Integer, Integer>> global = loop
            .flatMap(new FlatMapFunction<Tuple2<Integer, LocalCluster>, Tuple2<Integer, Integer>>() {
                private static final long serialVersionUID = -4049782728006537532L;

                @Override
                public void flatMap(Tuple2<Integer, LocalCluster> value, Collector<Tuple2<Integer, Integer>> out) {
                    int[] keys = value.f1.getKeys();
                    int[] clusterIds = value.f1.getClusterIds();
                    for (int i = 0; i < keys.length; i++) {
                        out.collect(Tuple2.of(keys[i], clusterIds[i]));
                    }
                }
            })
            .groupBy(0)
            .aggregate(Aggregations.MAX, 1);

        // LocalMerge re-applies the static per-partition maps (pre-iteration data).
        // The evolving state reaches it only through the broadcast global map.
        DataSet<Tuple3<Integer, LocalCluster, Boolean>> merged = localClusters
            .mapPartition(new LocalMerge())
            .withBroadcastSet(global, GLOBAL_BROADCAST);

        // The termination set keeps only partitions that still changed the map.
        // The loop ends once that set is empty.
        DataSet<Tuple2<Integer, LocalCluster>> converged = loop.closeWith(
            merged.<Tuple2<Integer, LocalCluster>>project(0, 1),
            merged.filter(new FilterFunction<Tuple3<Integer, LocalCluster, Boolean>>() {
                private static final long serialVersionUID = -1489238776369734510L;

                @Override
                public boolean filter(Tuple3<Integer, LocalCluster, Boolean> value) {
                    return !value.f2;
                }
            }));

        // Step 6: renumber cluster ids, then label every row.
        DataSet<Tuple2<Integer, Integer>> clusterIds = converged.mapPartition(new AssignContinuousClusterId());

        DataSet<Tuple3<Integer, Type, Integer>> labelled = pointTypes
            .leftOuterJoin(clusterIds, JoinOperatorBase.JoinHint.BROADCAST_HASH_SECOND)
            .where(0)
            .equalTo(0)
            .with(new AssignAllClusterId());

        DataSet<Row> output = indexed
            .join(labelled)
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Tuple3<Integer, Object, Vector>, Tuple3<Integer, Type, Integer>, Row>() {
                private static final long serialVersionUID = 7638483527530324994L;

                @Override
                public Row join(Tuple3<Integer, Object, Vector> point, Tuple3<Integer, Type, Integer> label) {
                    return Row.of(point.f1, label.f1.name(), label.f2.longValue());
                }
            });

        // Step 7: the model holds the vectors of CORE rows with their (renumbered) cluster ids.
        DataSet<Row> model = labelled
            .flatMap(new FlatMapFunction<Tuple3<Integer, Type, Integer>, Tuple2<Integer, Long>>() {
                private static final long serialVersionUID = -4449631564554949600L;

                @Override
                public void flatMap(Tuple3<Integer, Type, Integer> value, Collector<Tuple2<Integer, Long>> out) {
                    if (Type.CORE.equals(value.f1)) {
                        out.collect(Tuple2.of(value.f0, value.f2.longValue()));
                    }
                }
            })
            .join(indexed)
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Tuple2<Integer, Long>, Tuple3<Integer, Object, Vector>, Tuple2<Vector, Long>>() {
                private static final long serialVersionUID = -1388744253754875541L;

                @Override
                public Tuple2<Vector, Long> join(Tuple2<Integer, Long> core, Tuple3<Integer, Object, Vector> point) {
                    return Tuple2.of(point.f2, core.f1);
                }
            })
            .mapPartition(new SaveModel(vectorCol, getEpsilon(), getDistanceType()))
            .setParallelism(1);

        TypeInformation<?> idType =
            in.getColTypes()[TableUtil.findColIndexWithAssertAndHint(in.getColNames(), idCol)];
        setOutput(output, new String[] {idCol, DbscanConstant.TYPE, getPredictionCol()},
            new TypeInformation[] {idType, AlinkTypes.STRING, AlinkTypes.LONG});
        setSideOutputTables(new Table[] {
            DataSetConversionUtil.toTable(getMLEnvironmentId(), model, new DbscanModelDataConverter().getModelSchema())
        });
        return this;
    }

    @Override
    public DbscanModelInfoBatchOp getModelInfoBatchOp() {
        return new DbscanModelInfoBatchOp(getParams()).linkFrom(this);
    }

    /**
     * Unions one CORE point's neighbor ids into a key -> cluster-id map, in place.
     *
     * <p>The merged cluster id is the largest of:
     * <ul>
     *   <li>the last element of {@code ids};</li>
     *   <li>the current cluster ids of any members already present.</li>
     * </ul>
     * New members are mapped to that id. For existing members with a smaller cluster id, the entry keyed by
     * that old cluster id is re-pointed. Finally {@link #reduceTreeMap} compresses chains.
     *
     * <p>NOTE(review): correctness relies on two properties that this class's callers appear to maintain;
     * confirm before reusing the method elsewhere:
     * <ul>
     *   <li>{@code ids} is sorted ascending;</li>
     *   <li>every map value is itself a key and is never smaller than the key mapping to it.</li>
     * </ul>
     *
     * @param treeMap key -> cluster-id map, modified in place
     * @param ids     neighbor ids of one CORE point; an empty array leaves the map unchanged
     */
    public static void updateTreeMap(TreeMap<Integer, Integer> treeMap, int[] ids) {
        if (ids.length == 0) {
            return;
        }
        int target = ids[ids.length - 1];
        for (int id : ids) {
            target = Math.max(treeMap.getOrDefault(id, target), target);
        }
        for (int id : ids) {
            Integer current = treeMap.get(id);
            if (null == current) {
                treeMap.put(id, target);
            } else if (target > current) {
                // Re-point the entry keyed by id's current cluster id so every key in that cluster follows.
                treeMap.put(treeMap.get(current), target);
            }
        }
        reduceTreeMap(treeMap);
    }

    /**
     * Applies one local cluster map to the global map.
     *
     * <p>For every (key, localClusterId) pair: if the key's global cluster id is larger than the global cluster
     * id of localClusterId, the latter entry is raised to it. Chains are then compressed with
     * {@link #reduceTreeMap}.
     *
     * <p>Every key and cluster id of {@code localCluster} must be present in {@code treeMap}; otherwise this
     * throws a {@code NullPointerException}.
     *
     * @param treeMap      global key -> cluster-id map, modified in place
     * @param localCluster local key -> cluster-id pairs
     * @return {@code true} if no entry was changed (converged), {@code false} otherwise
     */
    public static boolean updateTreeMap(TreeMap<Integer, Integer> treeMap, LocalCluster localCluster) {
        int[] keys = localCluster.getKeys();
        int[] clusterIds = localCluster.getClusterIds();
        boolean unchanged = true;
        for (int i = 0; i < keys.length; i++) {
            int clusterId = clusterIds[i];
            int keyTarget = treeMap.get(keys[i]);
            if (keyTarget > treeMap.get(clusterId)) {
                unchanged = false;
                treeMap.put(clusterId, keyTarget);
            }
        }
        reduceTreeMap(treeMap);
        return unchanged;
    }

    /**
     * Compresses key -> cluster-id chains in a single pass over the keys in descending order.
     *
     * <p>Each key whose cluster id maps further to a larger id is re-pointed to that larger id.
     * Only values are overwritten (no structural modification), so iterating the key view while
     * {@code put}ting is safe. Every value must itself be a key of the map.
     *
     * @param treeMap key -> cluster-id map, modified in place
     */
    public static void reduceTreeMap(TreeMap<Integer, Integer> treeMap) {
        for (Integer key : treeMap.descendingKeySet()) {
            int parent = treeMap.get(key);
            int grandParent = treeMap.get(parent);
            if (grandParent > parent) {
                treeMap.put(key, grandParent);
            }
        }
    }

    /**
     * Converts a key -> cluster-id map to a {@link LocalCluster}.
     *
     * @param treeMap key -> cluster-id map
     * @return a LocalCluster whose parallel key / cluster-id arrays follow ascending key order
     */
    public static LocalCluster treeMapToLocalCluster(TreeMap<Integer, Integer> treeMap) {
        int[] keys = new int[treeMap.size()];
        int[] clusterIds = new int[treeMap.size()];
        int i = 0;
        for (Map.Entry<Integer, Integer> entry : treeMap.entrySet()) {
            keys[i] = entry.getKey();
            clusterIds[i] = entry.getValue();
            i++;
        }
        return new LocalCluster(keys, clusterIds);
    }

    /**
     * Left-outer-join function that attaches the final cluster id to every point.
     *
     * <p>Points without a matching cluster entry get {@link Dbscan#NOISE}. Points classified as NOISE that do
     * have a cluster entry are re-labelled LINKED. The input tuple is not modified.
     */
    public static class AssignAllClusterId implements JoinFunction<Tuple3<Integer, Type, int[]>, Tuple2<Integer, Integer>, Tuple3<Integer, Type, Integer>> {
        private static final long serialVersionUID = -3817364693537340163L;

        @Override
        public Tuple3<Integer, Type, Integer> join(Tuple3<Integer, Type, int[]> point, Tuple2<Integer, Integer> cluster) {
            if (null == cluster) {
                return Tuple3.of(point.f0, point.f1, Integer.valueOf(Dbscan.NOISE));
            }
            Type type = Type.NOISE.equals(point.f1) ? Type.LINKED : point.f1;
            return Tuple3.of(point.f0, type, cluster.f1);
        }
    }

    /**
     * Renumbers the converged cluster ids to consecutive integers starting at 0.
     *
     * <p>Only subtask 0 emits output, and it reads only the first record of its partition; other subtasks emit
     * nothing. Each record produced by {@link LocalMerge} carries the whole global map. New ids are handed out
     * in ascending key order.
     */
    public static class AssignContinuousClusterId extends RichMapPartitionFunction<Tuple2<Integer, LocalCluster>, Tuple2<Integer, Integer>> {
        private static final long serialVersionUID = -503944144706407009L;

        @Override
        public void mapPartition(Iterable<Tuple2<Integer, LocalCluster>> values, Collector<Tuple2<Integer, Integer>> out) throws Exception {
            if (getRuntimeContext().getIndexOfThisSubtask() != 0) {
                return;
            }
            Iterator<Tuple2<Integer, LocalCluster>> it = values.iterator();
            if (!it.hasNext()) {
                return;
            }
            LocalCluster global = it.next().f1;
            int[] keys = global.getKeys();
            int[] clusterIds = global.getClusterIds();
            Map<Integer, Integer> renumbered = new HashMap<>();
            for (int i = 0; i < keys.length; i++) {
                Integer newId = renumbered.get(clusterIds[i]);
                if (null == newId) {
                    newId = renumbered.size();
                    renumbered.put(clusterIds[i], newId);
                }
                out.collect(Tuple2.of(keys[i], newId));
            }
        }
    }

    /**
     * Classifies each point from its neighbor list.
     *
     * <p>Input rows are (unique id, serialized neighbor list). A point with at least {@code minPoints} neighbor
     * ids becomes CORE with its neighbor ids sorted ascending. Otherwise, or when no list can be extracted, it is
     * NOISE with an empty array.
     */
    public static class GetCorePoints implements MapPartitionFunction<Row, Tuple3<Integer, Type, int[]>> {
        private final int minPoints;

        public GetCorePoints(int minPoints) {
            this.minPoints = minPoints;
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Tuple3<Integer, Type, int[]>> out) {
            for (Row row : rows) {
                Integer id = (Integer) row.getField(0);
                List<?> neighborIds =
                    (List<?>) NearestNeighborsMapper.extractKObject((String) row.getField(1), Integer.class).f0;
                if (null == neighborIds || neighborIds.size() < minPoints) {
                    out.collect(Tuple3.of(id, Type.NOISE, new int[0]));
                } else {
                    int[] sorted = new int[neighborIds.size()];
                    for (int i = 0; i < sorted.length; i++) {
                        sorted[i] = ((Integer) neighborIds.get(i)).intValue();
                    }
                    Arrays.sort(sorted);
                    out.collect(Tuple3.of(id, Type.CORE, sorted));
                }
            }
        }
    }

    /**
     * One round of the global merge inside the bulk iteration.
     *
     * <p>{@link #open} loads the broadcast {@code "global"} (key, max cluster id) pairs into a map and compresses
     * it. {@link #mapPartition} applies every local cluster of this partition to that map. It then emits
     * (first record's partition key, whole global map, converged flag). The flag is {@code true} only when no
     * local cluster changed the map.
     */
    public static class LocalMerge extends RichMapPartitionFunction<Tuple2<Integer, LocalCluster>, Tuple3<Integer, LocalCluster, Boolean>> {
        private static final long serialVersionUID = -7265351591038567537L;
        private TreeMap<Integer, Integer> global;

        @Override
        public void open(Configuration configuration) {
            List<Tuple2<Integer, Integer>> broadcast = getRuntimeContext().getBroadcastVariable(GLOBAL_BROADCAST);
            this.global = new TreeMap<>();
            for (Tuple2<Integer, Integer> entry : broadcast) {
                this.global.put(entry.f0, entry.f1);
            }
            DbscanBatchOp.reduceTreeMap(this.global);
        }

        @Override
        public void mapPartition(Iterable<Tuple2<Integer, LocalCluster>> localClusters, Collector<Tuple3<Integer, LocalCluster, Boolean>> out) {
            boolean converged = true;
            Integer partitionKey = null;
            for (Tuple2<Integer, LocalCluster> local : localClusters) {
                if (null == partitionKey) {
                    partitionKey = local.f0;
                }
                // '&=' always runs the update and keeps a "changed" verdict from any earlier record.
                converged &= DbscanBatchOp.updateTreeMap(this.global, local.f1);
            }
            if (null != partitionKey) {
                out.collect(Tuple3.of(partitionKey, DbscanBatchOp.treeMapToLocalCluster(this.global), converged));
            }
        }
    }

    /**
     * Folds a subtask's local key -> cluster-id map from the neighbor lists of its CORE points.
     *
     * <p>Every input must be either NOISE with an empty neighbor array or non-NOISE with a non-empty one.
     * Emits exactly one (subtask index, local map) record, even when the partition has no CORE points.
     */
    public static class ReduceLocal extends RichMapPartitionFunction<Tuple3<Integer, Type, int[]>, Tuple2<Integer, LocalCluster>> {
        private static final long serialVersionUID = -6516417460468964305L;

        @Override
        public void mapPartition(Iterable<Tuple3<Integer, Type, int[]>> points, Collector<Tuple2<Integer, LocalCluster>> out) {
            TreeMap<Integer, Integer> local = new TreeMap<>();
            for (Tuple3<Integer, Type, int[]> point : points) {
                boolean hasNeighbors = point.f2.length > 0;
                Preconditions.checkArgument(point.f1.equals(Type.NOISE) ^ hasNeighbors, "Noise must be empty!");
                if (hasNeighbors) {
                    DbscanBatchOp.updateTreeMap(local, point.f2);
                }
            }
            out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), DbscanBatchOp.treeMapToLocalCluster(local)));
        }
    }

    /** Writes the CORE (vector, cluster id) pairs plus the search settings as the DBSCAN model. */
    public static class SaveModel implements MapPartitionFunction<Tuple2<Vector, Long>, Row> {
        private static final long serialVersionUID = 7638276873515252678L;
        private final String vectorColName;
        private final double epsilon;
        private final HasClusteringDistanceType.DistanceType distanceType;

        public SaveModel(String vectorColName, double epsilon, HasClusteringDistanceType.DistanceType distanceType) {
            this.vectorColName = vectorColName;
            this.epsilon = epsilon;
            this.distanceType = distanceType;
        }

        @Override
        public void mapPartition(Iterable<Tuple2<Vector, Long>> coreObjects, Collector<Row> out) throws Exception {
            DbscanModelTrainData data = new DbscanModelTrainData();
            data.coreObjects = coreObjects;
            data.vectorColName = this.vectorColName;
            data.epsilon = this.epsilon;
            data.distanceType = this.distanceType;
            new DbscanModelDataConverter().save(data, out);
        }
    }
}
