package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.GroupGeoDbscanBatchOp;
import com.alibaba.alink.operator.common.clustering.DistanceType;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanCenter;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.params.clustering.GroupDbscanModelParams;
import com.alibaba.alink.params.shared.colname.HasPredictionCol;
import java.util.Iterator;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import scala.util.hashing.MurmurHash3;

/**
 * Batch operator that produces a per-group DBSCAN model table.
 *
 * <p>{@link #linkFrom(BatchOperator[])} selects the group and feature columns of the first input,
 * clusters the rows of every group with {@code GroupGeoDbscanBatchOp.Clustering}, and emits one
 * row per (group, cluster) pair laid out as
 * {@code [group columns..., prediction column, count column, feature columns...]}, where the
 * feature columns hold the mean of the cluster members' feature values.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.MODEL_INFO)})
// NOTE(review): portIndices references VectorSerialType.DENSE_VECTOR, which looks like a decompiler
// constant substitution for a literal port index - confirm against the original source.
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "groupCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@ReservedColsWithSecondInputSpec
@NameCn("分组Dbscan模型")
@NameEn("Group Dbscan Model")
public final class GroupDbscanModelBatchOp extends BatchOperator<GroupDbscanModelBatchOp> implements GroupDbscanModelParams<GroupDbscanModelBatchOp> {
    private static final long serialVersionUID = 5788206252024914272L;

    /**
     * Converts a cluster center into an output row:
     * {@code [group values..., cluster id, member count, center coordinates...]}.
     */
    public static class MapToRow implements MapFunction<DbscanCenter<DenseVector>, Row> {
        private static final long serialVersionUID = -5480092592936407825L;
        private final int rowArity;
        private final int groupColNamesSize;

        /**
         * @param i  arity of the complete output row (group columns included)
         * @param i2 number of group columns
         */
        public MapToRow(int i, int i2) {
            this.rowArity = i;
            this.groupColNamesSize = i2;
        }

        /**
         * Builds the non-group part of the row (cluster id, count, coordinates) and prepends the
         * center's group row to it.
         */
        @Override
        public Row map(DbscanCenter<DenseVector> dbscanCenter) throws Exception {
            // The group values are merged in front below, so this row only holds the remainder.
            Row tail = new Row(this.rowArity - this.groupColNamesSize);
            DenseVector center = dbscanCenter.getValue();
            tail.setField(0, Long.valueOf(dbscanCenter.getClusterId()));
            tail.setField(1, Long.valueOf(dbscanCenter.getCount()));
            for (int k = 0; k < center.size(); k++) {
                // Coordinates are always written as Double, whatever the input column type was.
                tail.setField(k + 2, Double.valueOf(center.get(k)));
            }
            return RowUtil.merge(dbscanCenter.getGroupColNames(), tail);
        }
    }

    /**
     * Key selector that groups samples by a hash of (cluster id, group hash key).
     *
     * <p>NOTE(review): this is a non-static inner class, so every instance holds a reference to
     * the enclosing operator; it is kept that way to leave the public interface untouched.
     * NOTE(review): the key is a 32-bit hash (and the cluster id is narrowed to int first), so two
     * different (group, cluster) pairs could in principle collide into one reduce group - verify
     * whether that is acceptable.
     */
    public class SelectGroupAndClusterID implements KeySelector<DbscanNewSample, Integer> {
        private static final long serialVersionUID = -8204871256389225863L;

        public SelectGroupAndClusterID() {
        }

        /** Returns the MurmurHash3 array hash (seed 0) of the sample's cluster id and group hash key. */
        @Override
        public Integer getKey(DbscanNewSample dbscanNewSample) {
            Integer[] keyParts = new Integer[]{
                Integer.valueOf((int) dbscanNewSample.getClusterId()),
                Integer.valueOf(dbscanNewSample.getGroupHashKey())
            };
            return Integer.valueOf(new MurmurHash3().arrayHash(keyParts, 0));
        }
    }

    /**
     * Reduces all samples of one (group, cluster) key to a single center: the arithmetic mean of
     * the member vectors together with the cluster id, the member count and the group row.
     */
    public static class getClusteringCenter implements GroupReduceFunction<DbscanNewSample, DbscanCenter<DenseVector>> {
        private static final long serialVersionUID = 6317085010066332931L;
        private final int dim;
        // Not read by reduce(); retained because it is part of the public constructor.
        private final DistanceType distanceType;

        /**
         * @param i            dimension of the feature vectors
         * @param distanceType distance type of the clustering (currently unused by this function)
         */
        public getClusteringCenter(int i, DistanceType distanceType) {
            this.dim = i;
            this.distanceType = distanceType;
        }

        /**
         * Emits at most one center for the given samples. The cluster id and the group row are
         * taken from the first sample. Nothing is emitted when the input is empty or when the
         * cluster id is not greater than {@code Integer.MIN_VALUE} (presumably the noise marker -
         * confirm against the clustering function).
         */
        @Override
        public void reduce(Iterable<DbscanNewSample> iterable, Collector<DbscanCenter<DenseVector>> collector) throws Exception {
            Iterator<DbscanNewSample> samples = iterable.iterator();
            if (!samples.hasNext()) {
                // Guard: without it an empty input would yield a center with a null group row,
                // count 0 and a vector scaled by 1.0 / 0.
                return;
            }
            DbscanNewSample first = samples.next();
            long clusterId = first.getClusterId();
            if (clusterId <= Integer.MIN_VALUE) {
                return;
            }
            Row groupRow = first.getVec().getRows()[0];
            DenseVector sum = new DenseVector(this.dim);
            sum.plusEqual(first.getVec().getVector());
            int count = 1;
            while (samples.hasNext()) {
                sum.plusEqual(samples.next().getVec().getVector());
                count++;
            }
            // Turn the accumulated sum into the mean in place.
            sum.scaleEqual(1.0d / count);
            collector.collect(new DbscanCenter<DenseVector>(groupRow, clusterId, count, sum));
        }
    }

    /**
     * Converts an input row laid out as {@code [group columns..., feature columns...]} into a
     * {@link DbscanNewSample}.
     */
    public static class mapToDataSample implements MapFunction<Row, DbscanNewSample> {
        private static final long serialVersionUID = 1491814462425438888L;
        private final int dim;
        private final int groupColNamesSize;
        private final FastDistance distance;

        /**
         * @param i            number of feature columns
         * @param i2           number of group columns (they come first in the row)
         * @param fastDistance distance used to prepare the vector data of each sample
         */
        public mapToDataSample(int i, int i2, FastDistance fastDistance) {
            this.dim = i;
            this.groupColNamesSize = i2;
            this.distance = fastDistance;
        }

        /**
         * Reads the leading group values (both as strings and as a row) and the following numeric
         * feature values, and wraps them into a sample.
         */
        @Override
        public DbscanNewSample map(Row row) throws Exception {
            String[] groupValues = new String[this.groupColNamesSize];
            Row groupRow = new Row(this.groupColNamesSize);
            for (int g = 0; g < this.groupColNamesSize; g++) {
                Object groupValue = row.getField(g);
                groupValues[g] = groupValue.toString();
                groupRow.setField(g, groupValue);
            }
            double[] features = new double[this.dim];
            for (int f = 0; f < features.length; f++) {
                // The featureCols column spec allows every numeric type, so accept any Number;
                // a cast to Double would throw ClassCastException for e.g. Integer or Long columns.
                features[f] = ((Number) row.getField(f + this.groupColNamesSize)).doubleValue();
            }
            DenseVector denseVector = new DenseVector(features);
            // Tuple2.of is kept inline so its type arguments are inferred from the parameter type.
            return new DbscanNewSample(this.distance.prepareVectorData(Tuple2.of(denseVector, groupRow)), groupValues);
        }
    }

    /** Creates the operator with empty parameters. */
    public GroupDbscanModelBatchOp() {
        this(new Params());
    }

    /** Creates the operator with the given parameters. */
    public GroupDbscanModelBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the model table from the first input operator.
     *
     * <p>The prediction column defaults to {@code "cluster_id"} when it is not set. Group columns
     * must be disjoint from the feature columns, and the JACCARD distance is rejected.
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator, with its output set to the model table
     * @throws IllegalArgumentException if a group column is also a feature column, or if the
     *                                  distance type is JACCARD
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public GroupDbscanModelBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);
        if (!getParams().contains(HasPredictionCol.PREDICTION_COL)) {
            setPredictionCol("cluster_id");
        }
        DistanceType distanceType = getDistanceType();
        String[] featureCols = getParams().get(FEATURE_COLS);
        int minPoints = getParams().get(MIN_POINTS);
        Double epsilon = getParams().get(EPSILON);
        String predictionCol = getPredictionCol();
        String[] groupCols = getParams().get(GROUP_COLS);
        for (String groupCol : groupCols) {
            if (TableUtil.findColIndex(featureCols, groupCol) >= 0) {
                throw new IllegalArgumentException("groupColNames should NOT be included in featureColNames!");
            }
        }
        // Column order fed to mapToDataSample: group columns first, then feature columns.
        String[] trainCols = ArrayUtils.addAll(groupCols, featureCols);
        Preconditions.checkArgument(distanceType != DistanceType.JACCARD, "Not support %s!", distanceType.name());
        FastDistance fastDistance = distanceType.getFastDistance();
        int featureDim = featureCols.length;
        String[] outputCols = ArrayUtils.addAll(
            ArrayUtils.addAll(groupCols, new String[]{predictionCol, DbscanConstant.COUNT}),
            featureCols);

        DataSet<Row> centerRows = input
            .select(trainCols)
            .getDataSet()
            .map(new mapToDataSample(featureDim, groupCols.length, fastDistance))
            // Cluster the samples of every group.
            .groupBy(new GroupGeoDbscanBatchOp.SelectGroup())
            .reduceGroup(new GroupGeoDbscanBatchOp.Clustering(epsilon.doubleValue(), minPoints, fastDistance, getGroupMaxSamples().intValue(), getSkip().booleanValue()))
            // Average the members of every (group, cluster) pair into one center.
            .groupBy(new SelectGroupAndClusterID())
            .reduceGroup(new getClusteringCenter(featureDim, distanceType))
            .map(new MapToRow(outputCols.length, groupCols.length));

        TableSchema inputSchema = input.getSchema();
        TypeInformation<?>[] groupTypes = TableUtil.findColTypesWithAssert(inputSchema, groupCols);
        // Still invoked for whatever checks it performs on the feature columns; its result is not
        // used because MapToRow writes every center coordinate as a Double, so those columns are
        // declared DOUBLE instead of echoing the input types.
        TableUtil.findColTypesWithAssertAndHint(inputSchema, featureCols);
        TypeInformation<?>[] outputTypes = new TypeInformation<?>[outputCols.length];
        System.arraycopy(groupTypes, 0, outputTypes, 0, groupCols.length);
        outputTypes[groupCols.length] = Types.LONG;
        outputTypes[groupCols.length + 1] = Types.LONG;
        for (int c = groupCols.length + 2; c < outputTypes.length; c++) {
            outputTypes[c] = Types.DOUBLE;
        }

        setOutput(centerRows, new TableSchema(outputCols, outputTypes));
        return this;
    }

    // The decompiler-emitted synthetic bridge overload linkFrom(BatchOperator[]) is intentionally
    // absent: it has the same erasure as the varargs method above, and the compiler generates the
    // bridge for the covariant return type on its own.
}
