package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.dbscan.Dbscan;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample;
import com.alibaba.alink.operator.common.clustering.dbscan.Type;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceData;
import com.alibaba.alink.operator.common.distance.HaversineDistance;
import com.alibaba.alink.params.clustering.GroupGeoDbscanParams;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "idCol", portIndices = {0}),
    @ParamSelectColumnSpec(name = "groupCols", portIndices = {0}),
    @ParamSelectColumnSpec(name = "latitudeCol", portIndices = {0},
        allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}),
    @ParamSelectColumnSpec(name = "longitudeCol", portIndices = {0},
        allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})
})
@NameCn("分组经纬度Dbscan")
@ReservedColsWithFirstInputSpec
@NameEn("Group Geo Dbscan")
/**
 * Batch operator that runs DBSCAN separately inside each group of rows, clustering on the
 * (latitude, longitude) columns with a {@link HaversineDistance}.
 *
 * <p>The output table has the columns {@code [DbscanConstant.TYPE, predictionCol, reservedCols...]}:
 * the sample type name (e.g. CORE / LINKED / NOISE), the cluster id as a long, and the reserved
 * input columns (if any were configured).
 */
public class GroupGeoDbscanBatchOp extends BatchOperator<GroupGeoDbscanBatchOp>
    implements GroupGeoDbscanParams<GroupGeoDbscanBatchOp> {
    private static final long serialVersionUID = -1650606375272968610L;

    /**
     * Runs DBSCAN over all samples that share one group key.
     *
     * <p>If a group holds more than {@code groupMaxSamples} samples it is either dropped entirely
     * ({@code skip == true}) or a random subset of {@code groupMaxSamples} samples is clustered.
     * Each remaining sample then takes the cluster id of its nearest CORE sample. It is typed
     * LINKED when that distance is within {@code epsilon}, otherwise NOISE.
     */
    public static class Clustering extends RichGroupReduceFunction<DbscanNewSample, DbscanNewSample> {
        private static final long serialVersionUID = 3474119012459738732L;

        /** Cluster id given to left-over samples that are farther than epsilon from every core. */
        private static final long NOISE_CLUSTER_ID = Integer.MIN_VALUE;

        private final double epsilon;
        private final int minPoints;
        private final FastDistance baseDistance;
        private final int groupMaxSamples;
        private final boolean skip;

        /**
         * @param epsilon         neighbourhood radius passed to DBSCAN and used for leftover linking
         * @param minPoints       minimum neighbourhood size passed to DBSCAN
         * @param baseDistance    distance used both by DBSCAN and for leftover linking
         * @param groupMaxSamples largest group that is clustered in full
         * @param skip            drop over-sized groups instead of down-sampling them
         */
        public Clustering(double epsilon, int minPoints, FastDistance baseDistance, int groupMaxSamples,
                          boolean skip) {
            this.epsilon = epsilon;
            this.minPoints = minPoints;
            this.baseDistance = baseDistance;
            this.groupMaxSamples = groupMaxSamples;
            this.skip = skip;
        }

        @Override
        public void reduce(Iterable<DbscanNewSample> values, Collector<DbscanNewSample> out) throws Exception {
            ArrayList<DbscanNewSample> samples = new ArrayList<>();
            for (DbscanNewSample sample : values) {
                samples.add(sample);
            }

            ArrayList<DbscanNewSample> leftovers = null;
            // A group of exactly groupMaxSamples fits within the limit (the down-sampling branch
            // below clusters that many samples), so only strictly larger groups are skipped or
            // down-sampled.
            if (samples.size() > groupMaxSamples) {
                if (skip) {
                    return;
                }
                Collections.shuffle(samples);
                leftovers = new ArrayList<>(samples.subList(groupMaxSamples, samples.size()));
                samples = new ArrayList<>(samples.subList(0, groupMaxSamples));
            }

            int nextClusterId = 0;
            for (DbscanNewSample sample : samples) {
                // NOTE(review): -1 presumably means "not yet assigned by expandCluster" — confirm
                // against DbscanNewSample's initial cluster id.
                if (sample.getClusterId() == -1
                    && Dbscan.expandCluster(samples, sample, nextClusterId, epsilon, minPoints, baseDistance)) {
                    nextClusterId++;
                }
            }
            for (DbscanNewSample sample : samples) {
                out.collect(sample);
            }

            if (leftovers == null) {
                return;
            }
            // Core samples do not change while leftovers are linked, so collect them once.
            ArrayList<DbscanNewSample> cores = new ArrayList<>();
            for (DbscanNewSample sample : samples) {
                if (sample.getType().equals(Type.CORE)) {
                    cores.add(sample);
                }
            }
            for (DbscanNewSample leftover : leftovers) {
                attachToNearestCore(leftover, cores);
                out.collect(leftover);
            }
        }

        /**
         * Gives {@code sample} the cluster id of its nearest core (first one wins on ties). It is
         * typed LINKED if that core is within epsilon; otherwise (including when there are no
         * cores) it is typed NOISE with {@link #NOISE_CLUSTER_ID}.
         */
        private void attachToNearestCore(DbscanNewSample sample, ArrayList<DbscanNewSample> cores) {
            double nearest = Double.POSITIVE_INFINITY;
            for (DbscanNewSample core : cores) {
                double dist = baseDistance
                    .calc((FastDistanceData) core.getVec(), (FastDistanceData) sample.getVec())
                    .get(0, 0);
                if (dist < nearest) {
                    nearest = dist;
                    sample.setClusterId(core.getClusterId());
                }
            }
            if (nearest > epsilon) {
                sample.setType(Type.NOISE);
                sample.setClusterId(NOISE_CLUSTER_ID);
            } else {
                sample.setType(Type.LINKED);
            }
        }
    }

    /**
     * Converts a clustered sample into an output row {@code [typeName, clusterId, reserved...]}.
     * The trailing fields come from the first row of the sample's vector data, presumably the
     * reserved-column row attached in {@link mapToDataVectorSample}.
     */
    public static class MapToRow extends RichMapFunction<DbscanNewSample, Row> {
        private static final long serialVersionUID = 5024255660037882136L;

        @Override
        public Row map(DbscanNewSample sample) throws Exception {
            Row head = Row.of(sample.getType().name(), sample.getClusterId());
            return RowUtil.merge(head, sample.getVec().getRows()[0]);
        }
    }

    /**
     * Groups samples by {@link DbscanNewSample#getGroupHashKey()}.
     *
     * <p>NOTE(review): this key is an int hash of the group values. If two distinct groups collide,
     * they are clustered together. Confirm whether DbscanNewSample guards against this.
     */
    public static class SelectGroup implements KeySelector<DbscanNewSample, Integer> {
        private static final long serialVersionUID = 6268163376874147254L;

        @Override
        public Integer getKey(DbscanNewSample sample) {
            return sample.getGroupHashKey();
        }
    }

    /** Spreads group keys over partitions by their (non-negative) remainder. */
    public static class WeightPartitioner implements Partitioner<Integer> {
        private static final long serialVersionUID = -4197634749052990621L;

        @Override
        public int partition(Integer key, int numPartitions) {
            // Math.abs(key) % n is negative for key == Integer.MIN_VALUE. Taking the remainder
            // first gives the same result for every other key and is always in [0, n).
            return Math.abs(key % numPartitions);
        }
    }

    /**
     * Turns an input row laid out as {@code [groupCols..., latitude, longitude, reserved...]} into
     * a {@link DbscanNewSample}. The sample carries the string group values and the prepared
     * (latitude, longitude) vector data, with the reserved columns attached.
     */
    public static class mapToDataVectorSample extends RichMapFunction<Row, DbscanNewSample> {
        private static final long serialVersionUID = -9186022939852072237L;
        private final int groupColNamesSize;
        private final FastDistance distance;

        public mapToDataVectorSample(int groupColNamesSize, FastDistance distance) {
            this.groupColNamesSize = groupColNamesSize;
            this.distance = distance;
        }

        @Override
        public DbscanNewSample map(Row row) throws Exception {
            String[] groupValues = new String[groupColNamesSize];
            for (int i = 0; i < groupColNamesSize; i++) {
                // NOTE(review): a null group value throws NullPointerException here.
                groupValues[i] = row.getField(i).toString();
            }

            int latitudeIdx = groupColNamesSize;
            DenseVector latLon = new DenseVector(2);
            latLon.set(0, ((Number) row.getField(latitudeIdx)).doubleValue());
            latLon.set(1, ((Number) row.getField(latitudeIdx + 1)).doubleValue());

            int reservedStart = latitudeIdx + 2;
            Row reserved = new Row(row.getArity() - reservedStart);
            for (int i = 0; i < reserved.getArity(); i++) {
                reserved.setField(i, row.getField(reservedStart + i));
            }
            return new DbscanNewSample(distance.prepareVectorData(Tuple2.of(latLon, reserved)), groupValues);
        }
    }

    public GroupGeoDbscanBatchOp() {
        this(null);
    }

    public GroupGeoDbscanBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the per-group DBSCAN pipeline on the first input.
     *
     * @throws RuntimeException if idCol is unset/empty or is one of the group columns
     */
    @Override
    public GroupGeoDbscanBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String latitudeCol = getLatitudeCol();
        String longitudeCol = getLongitudeCol();
        int minPoints = getMinPoints();
        Double epsilon = (Double) getParams().get(EPSILON);
        String idCol = getIdCol();
        String predictionCol = getPredictionCol();
        int groupMaxSamples = getGroupMaxSamples();
        boolean skip = getSkip();
        String[] reservedCols = getReservedCols();
        HaversineDistance distance = new HaversineDistance();
        String[] groupCols = getGroupCols();

        if (null == idCol || "".equals(idCol)) {
            throw new RuntimeException("idCol column should be set!");
        }
        if (TableUtil.findColIndex(groupCols, idCol) >= 0) {
            throw new RuntimeException("idCol column should NOT be included in groupColNames !");
        }

        // Input rows are projected to [groupCols..., latitude, longitude, reserved...], which is
        // the layout mapToDataVectorSample expects.
        String[] outputColNames = new String[] {DbscanConstant.TYPE, predictionCol};
        TypeInformation<?>[] outputColTypes = new TypeInformation<?>[] {Types.STRING, Types.LONG};
        String[] inputColNames = ArrayUtils.addAll(groupCols, latitudeCol, longitudeCol);
        if (null != reservedCols) {
            outputColNames = ArrayUtils.addAll(outputColNames, reservedCols);
            outputColTypes = ArrayUtils.addAll(outputColTypes,
                TableUtil.findColTypesWithAssertAndHint(in.getSchema(), reservedCols));
            inputColNames = ArrayUtils.addAll(inputColNames, reservedCols);
        }

        DataSet<Row> clustered = in.select(inputColNames)
            .getDataSet()
            .map(new mapToDataVectorSample(groupCols.length, distance))
            .groupBy(new SelectGroup())
            .withPartitioner(new WeightPartitioner())
            .reduceGroup(new Clustering(epsilon, minPoints, distance, groupMaxSamples, skip))
            .map(new MapToRow());

        setOutput(clustered, new TableSchema(outputColNames, outputColTypes));
        return this;
    }
}
