package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.GroupGeoDbscanBatchOp;
import com.alibaba.alink.operator.common.clustering.DistanceType;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.params.clustering.GroupDbscanParams;
import com.alibaba.alink.params.clustering.HasLatitudeCol;
import com.alibaba.alink.params.clustering.HasLongitudeCol;
import com.alibaba.alink.params.shared.colname.HasPredictionCol;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "groupCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "idCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@ReservedColsWithSecondInputSpec
@NameCn("分组Dbscan")
@NameEn("Group Dbscan")
/* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/GroupDbscanBatchOp.class */
public class GroupDbscanBatchOp extends BatchOperator<GroupDbscanBatchOp> implements GroupDbscanParams<GroupDbscanBatchOp> {
    private static final long serialVersionUID = 2259660296918166445L;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/GroupDbscanBatchOp$MapToRow.class */
    public static class MapToRow extends RichMapFunction<DbscanNewSample, Row> {
        private static final long serialVersionUID = 4213429941831979236L;
        private int rowArity;
        private int groupColNamesSize;
        private Boolean isOutputVector;

        public MapToRow(Boolean bool, int i, int i2) {
            this.isOutputVector = bool;
            this.rowArity = i;
            this.groupColNamesSize = i2;
        }

        public Row map(DbscanNewSample dbscanNewSample) throws Exception {
            Row row = new Row((this.rowArity - this.groupColNamesSize) - 1);
            row.setField(0, dbscanNewSample.getType().name());
            row.setField(1, Long.valueOf(dbscanNewSample.getClusterId()));
            DenseVector denseVector = (DenseVector) dbscanNewSample.getVec().getVector();
            if (this.isOutputVector.booleanValue()) {
                row.setField(2, denseVector.toString());
            } else {
                for (int i = 0; i < denseVector.size(); i++) {
                    row.setField(i + 2, Double.valueOf(denseVector.get(i)));
                }
            }
            return RowUtil.merge(dbscanNewSample.getVec().getRows()[0], row);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/GroupDbscanBatchOp$mapToDataVectorSample.class */
    public static class mapToDataVectorSample extends RichMapFunction<Row, DbscanNewSample> {
        private static final long serialVersionUID = -6733405177253139009L;
        private int dim;
        private int groupColNamesSize;
        private FastDistance distance;

        public mapToDataVectorSample(int i, int i2, FastDistance fastDistance) {
            this.dim = i;
            this.groupColNamesSize = i2;
            this.distance = fastDistance;
        }

        public DbscanNewSample map(Row row) throws Exception {
            Row row2 = new Row(row.getArity() - this.dim);
            String[] strArr = new String[this.groupColNamesSize];
            for (int i = 0; i < this.groupColNamesSize; i++) {
                strArr[i] = row.getField(i).toString();
                row2.setField(i, strArr[i]);
            }
            row2.setField(this.groupColNamesSize, row.getField(this.groupColNamesSize).toString());
            double[] dArr = new double[this.dim];
            for (int i2 = 0; i2 < dArr.length; i2++) {
                dArr[i2] = ((Number) row.getField(i2 + this.groupColNamesSize + 1)).doubleValue();
            }
            return new DbscanNewSample(this.distance.prepareVectorData(Tuple2.of(new DenseVector(dArr), row2)), strArr);
        }
    }

    /**
     * Creates the operator with a newly created {@link Params} instance, delegating to
     * {@link #GroupDbscanBatchOp(Params)}.
     */
    public GroupDbscanBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters, which are passed unchanged to the
     * {@link BatchOperator} superclass constructor.
     *
     * @param params operator parameters
     */
    public GroupDbscanBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Runs DBSCAN independently inside each group of the first input operator.
     *
     * <p>Steps performed here:
     * <ol>
     *   <li>Fill in defaults: prediction column {@code "cluster_id"} and id column
     *       {@code "append_id"} when they are not present in the params.</li>
     *   <li>Resolve the feature columns. For {@link DistanceType#HAVERSINE} the latitude and
     *       longitude columns are used when both are set to non-empty names; otherwise the
     *       configured feature columns are used.</li>
     *   <li>Validate that feature, group and id columns are set and do not overlap.</li>
     *   <li>Select {@code groupCols + idCol + featureCols}, map every row to a sample, group the
     *       samples, cluster each group and map the result back to rows.</li>
     * </ol>
     *
     * <p>Output schema: the group columns and the id column (STRING), {@code DbscanConstant.TYPE}
     * (STRING), the prediction column (LONG), then either one
     * {@code DbscanConstant.FEATURE_COL_NAMES} column (STRING) when {@code IS_OUTPUT_VECTOR} is
     * true, or the individual feature columns (DOUBLE).
     *
     * @param inputs input operators; only the first one is used
     * @return this operator, with its output set
     * @throws IllegalArgumentException if the distance type is Jaccard or the column
     *                                  configuration is invalid
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public GroupDbscanBatchOp linkFrom(BatchOperator<?>... inputs) {
        final BatchOperator<?> input = checkAndGetFirst(inputs);
        if (!getParams().contains(HasPredictionCol.PREDICTION_COL)) {
            setPredictionCol("cluster_id");
        }
        if (!getParams().contains(GroupDbscanParams.ID_COL)) {
            setIdCol("append_id");
        }

        final Boolean isOutputVector = (Boolean) getParams().get(IS_OUTPUT_VECTOR);
        final DistanceType distanceType = getDistanceType();
        final boolean haversine = DistanceType.HAVERSINE.equals(distanceType);
        final String latitudeCol = getParams().contains(HasLatitudeCol.LATITUDE_COL)
            ? (String) getParams().get(HasLatitudeCol.LATITUDE_COL) : null;
        final String longitudeCol = getParams().contains(HasLongitudeCol.LONGITUDE_COL)
            ? (String) getParams().get(HasLongitudeCol.LONGITUDE_COL) : null;
        // An empty name counts as "not set", matching the validation below. Previously an empty
        // latitude/longitude name was still chosen as a feature column and slipped through the
        // length-2 check, so a blank column name reached select().
        final boolean latLonSet = latitudeCol != null && longitudeCol != null
            && !latitudeCol.isEmpty() && !longitudeCol.isEmpty();
        final String[] featureCols = (haversine && latLonSet)
            ? new String[]{latitudeCol, longitudeCol}
            : (String[]) get(FEATURE_COLS);
        final int minPoints = ((Integer) getParams().get(MIN_POINTS)).intValue();
        final Double epsilon = (Double) getParams().get(EPSILON);
        final String idCol = (String) getParams().get(ID_COL);
        final String predictionCol = getPredictionCol();
        final String[] groupCols = getGroupCols();

        Preconditions.checkArgument(distanceType != DistanceType.JACCARD, "Not support Jaccard Distance!");
        final FastDistance fastDistance = distanceType.getFastDistance();
        validateGroupDbscanCols(haversine, latLonSet, featureCols, groupCols, idCol);

        // Columns fed to the clustering: group columns, id, then the features.
        final int featureCount = featureCols.length;
        final String[] selectedCols = (String[]) ArrayUtils.addAll(
            ArrayUtils.addAll(groupCols, new String[]{idCol}), featureCols);

        // Output columns: group columns, id, sample type, cluster id, then the features.
        final boolean asVector = isOutputVector.booleanValue();
        final String[] leadingCols = (String[]) ArrayUtils.addAll(
            groupCols, new String[]{idCol, DbscanConstant.TYPE, predictionCol});
        final String[] outputCols = asVector
            ? (String[]) ArrayUtils.add(leadingCols, DbscanConstant.FEATURE_COL_NAMES)
            : (String[]) ArrayUtils.addAll(leadingCols, featureCols);

        final TypeInformation<?>[] outputTypes = new TypeInformation<?>[outputCols.length];
        Arrays.fill(outputTypes, Types.STRING);
        // The prediction column sits two positions after the group columns (after id and type).
        outputTypes[groupCols.length + 2] = Types.LONG;
        if (!asVector) {
            Arrays.fill(outputTypes, groupCols.length + 3, outputTypes.length, Types.DOUBLE);
        }

        setOutput(
            (DataSet<Row>) input.select(selectedCols).getDataSet()
                .map(new mapToDataVectorSample(featureCount, groupCols.length, fastDistance))
                .groupBy(new GroupGeoDbscanBatchOp.SelectGroup())
                .reduceGroup(new GroupGeoDbscanBatchOp.Clustering(
                    epsilon.doubleValue(), minPoints, fastDistance,
                    getGroupMaxSamples().intValue(), getSkip().booleanValue()))
                .map(new MapToRow(isOutputVector, outputCols.length, groupCols.length)),
            new TableSchema(outputCols, outputTypes));
        return this;
    }

    /**
     * Checks that the feature, group and id columns are set and mutually disjoint.
     *
     * @param haversine   whether the Haversine distance is used
     * @param latLonSet   whether both latitude and longitude columns are set to non-empty names
     * @param featureCols resolved feature columns, may be null
     * @param groupCols   group columns
     * @param idCol       id column, may be null
     * @throws IllegalArgumentException if any check fails
     */
    private static void validateGroupDbscanCols(boolean haversine, boolean latLonSet, String[] featureCols,
                                                String[] groupCols, String idCol) {
        if (haversine) {
            // Haversine needs exactly two coordinates: either lat/lon or two feature columns.
            if (!latLonSet && (featureCols == null || featureCols.length != 2)) {
                throw new IllegalArgumentException("latitudeColName and longitudeColName should be set !");
            }
        } else if (featureCols == null || featureCols.length == 0) {
            throw new IllegalArgumentException("featureColNames should be set !");
        }
        for (String groupCol : groupCols) {
            if (TableUtil.findColIndex(featureCols, groupCol) >= 0) {
                throw new IllegalArgumentException("groupColNames should NOT be included in featureColNames!");
            }
        }
        if (idCol == null || idCol.isEmpty()) {
            throw new IllegalArgumentException("idCol column should be set!");
        }
        if (TableUtil.findColIndex(featureCols, idCol) >= 0) {
            throw new IllegalArgumentException("idCol column should NOT be included in featureColNames !");
        }
        if (TableUtil.findColIndex(groupCols, idCol) >= 0) {
            throw new IllegalArgumentException("idCol column should NOT be included in groupColNames !");
        }
    }

    /**
     * Bridge overload emitted by the decompiler (see the {@code bridge}/{@code synthetic}
     * markers): casts the raw array to {@code BatchOperator<?>[]} and forwards the call to
     * {@code linkFrom}.
     *
     * <p>NOTE(review): this signature appears to have the same erasure as the varargs
     * {@code linkFrom(BatchOperator<?>...)} declared in this class, so it is presumably a
     * compiler-generated method that should not exist in hand-maintained source; confirm before
     * recompiling.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ GroupDbscanBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
