package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.clustering.agnes.Agnes;
import com.alibaba.alink.operator.common.clustering.agnes.AgnesCluster;
import com.alibaba.alink.operator.common.clustering.agnes.AgnesModelInfoBatchOp;
import com.alibaba.alink.operator.common.clustering.agnes.AgnesSample;
import com.alibaba.alink.operator.common.clustering.agnes.Linkage;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.params.clustering.AgnesParams;
import com.alibaba.alink.params.shared.clustering.HasClusteringDistanceType;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that runs Agnes clustering on the id column and vector column of its first input.
 *
 * <p>Main output: one row per sample with the id column and the prediction column (the index of the
 * cluster the sample was put into).
 *
 * <p>Side output 0: one row per sample with the columns {@code NodeId}, {@code MergeIteration} and
 * {@code ParentId}, taken from the sample's id, merge iteration and parent id.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {
    @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT),
    @PortSpec(PortType.EVAL_METRICS)
})
// NOTE(review): portIndices refers to VectorUtil.VectorSerialType.DENSE_VECTOR; this looks like a
// decompiler substituting a constant for a literal port index - confirm against the original source.
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "vectorCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.VECTOR_TYPES}),
    @ParamSelectColumnSpec(name = "idCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})
})
@ReservedColsWithSecondInputSpec
@NameCn("Agnes")
@NameEn("Agnes")
public final class AgnesBatchOp extends BatchOperator<AgnesBatchOp> implements AgnesParams<AgnesBatchOp>, WithModelInfoBatchOp<AgnesModelInfoBatchOp.AgnesModelSummary, AgnesBatchOp, AgnesModelInfoBatchOp> {
    private static final long serialVersionUID = -7069169801410116405L;

    /**
     * Buffers every sample of a partition and passes the whole list to
     * {@link Agnes#startAnalysis}, emitting each cluster it returns.
     */
    public static class AgnesKernel implements MapPartitionFunction<AgnesSample, AgnesCluster> {
        private static final long serialVersionUID = 886248302149838023L;
        private final double distanceThreshold;
        private final int k;
        private final ContinuousDistance distance;
        private final Linkage linkage;

        /**
         * @param d                  distance threshold forwarded to {@link Agnes#startAnalysis}
         * @param i                  value of the {@code K} parameter forwarded to {@link Agnes#startAnalysis}
         * @param continuousDistance distance measure forwarded to {@link Agnes#startAnalysis}
         * @param linkage            linkage forwarded to {@link Agnes#startAnalysis}
         */
        public AgnesKernel(double d, int i, ContinuousDistance continuousDistance, Linkage linkage) {
            this.distanceThreshold = d;
            this.k = i;
            this.distance = continuousDistance;
            this.linkage = linkage;
        }

        /**
         * Clusters all samples of the partition at once.
         *
         * @param iterable  samples of this partition
         * @param collector receives every cluster returned by the analysis
         */
        @Override
        public void mapPartition(Iterable<AgnesSample> iterable, Collector<AgnesCluster> collector) throws Exception {
            // The analysis needs the complete sample list, so the partition is materialized first.
            ArrayList<AgnesSample> samples = new ArrayList<>();
            for (AgnesSample sample : iterable) {
                samples.add(sample);
            }
            for (AgnesCluster cluster : Agnes.startAnalysis(samples, this.k, this.distanceThreshold, this.linkage, this.distance)) {
                collector.collect(cluster);
            }
        }
    }

    /**
     * Emits, for every sample of a cluster, a row of (sample id, merge iteration, parent id).
     */
    public static class MergeInfo implements FlatMapFunction<AgnesCluster, Row> {
        private static final long serialVersionUID = 531203134457473817L;
        private final TypeInformation<?> idType;

        /**
         * @param typeInformation type the sample id is converted to through {@link EvaluationUtil#castTo}
         */
        public MergeInfo(TypeInformation<?> typeInformation) {
            this.idType = typeInformation;
        }

        /**
         * @param agnesCluster cluster whose samples are written out
         * @param collector    receives one row per sample
         */
        @Override
        public void flatMap(AgnesCluster agnesCluster, Collector<Row> collector) throws Exception {
            for (AgnesSample sample : agnesCluster.getAgnesSamples()) {
                Object nodeId = EvaluationUtil.castTo(sample.getSampleId(), this.idType);
                // NOTE(review): the side-output schema declares ParentId with the id column type, but
                // the parent id is emitted without castTo - confirm the types agree for non-string ids.
                collector.collect(Row.of(nodeId, sample.getMergeIter(), sample.getParentId()));
            }
        }
    }

    /**
     * Emits, for every sample of a cluster, a row of (sample id, cluster index).
     *
     * <p>The cluster index is a counter held by this function instance and incremented once per
     * processed cluster, so indices are only unique when a single instance sees every cluster.
     */
    public static class TransferClusterResult implements FlatMapFunction<AgnesCluster, Row> {
        private static final long serialVersionUID = 531203134457473817L;
        private long clusterId = 0;
        private final TypeInformation<?> idType;

        /**
         * @param typeInformation type the sample id is converted to through {@link EvaluationUtil#castTo}
         */
        public TransferClusterResult(TypeInformation<?> typeInformation) {
            this.idType = typeInformation;
        }

        /**
         * @param agnesCluster cluster whose samples are written out
         * @param collector    receives one row per sample, all carrying the same cluster index
         */
        @Override
        public void flatMap(AgnesCluster agnesCluster, Collector<Row> collector) throws Exception {
            for (AgnesSample sample : agnesCluster.getAgnesSamples()) {
                collector.collect(Row.of(EvaluationUtil.castTo(sample.getSampleId(), this.idType), this.clusterId));
            }
            // Advance even for a cluster without samples, so numbering follows cluster order.
            this.clusterId++;
        }
    }

    /** Creates the operator without parameters. */
    public AgnesBatchOp() {
        super(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters handed to {@link BatchOperator}
     */
    public AgnesBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the clustering job on the first input operator.
     *
     * <p>Reads the id and vector columns, turns every row into an {@link AgnesSample}, clusters all
     * samples in a single partition and sets the main output and one side output table.
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator
     * @throws RuntimeException if {@code K} is at most 1 and the distance threshold equals
     *                          {@code Double.MAX_VALUE}
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public AgnesBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        final int k = getParams().get(K);
        final double distanceThreshold = getParams().get(DISTANCE_THRESHOLD);
        final HasClusteringDistanceType.DistanceType distanceType = get(DISTANCE_TYPE);
        final Linkage linkage = getParams().get(LINKAGE);
        final ContinuousDistance distance = distanceType.getFastDistance();
        // Double.MAX_VALUE is treated as "no threshold given" (presumably the parameter default -
        // confirm in AgnesParams); without a usable k there would be no stopping criterion.
        if (k <= 1 && distanceThreshold == Double.MAX_VALUE) {
            throw new RuntimeException("k should larger than 1,or distanceThreshold should be set");
        }
        final TypeInformation<?> idType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), getIdCol());

        // All samples are clustered by one AgnesKernel instance, hence parallelism 1.
        DataSet<AgnesCluster> clusters = in
            .select(new String[]{getIdCol(), getVectorCol()})
            .getDataSet()
            .map(new MapFunction<Row, AgnesSample>() { // from class: com.alibaba.alink.operator.batch.clustering.AgnesBatchOp.1
                private static final long serialVersionUID = -4667000522433310128L;

                @Override
                public AgnesSample map(Row row) throws Exception {
                    return new AgnesSample(row.getField(0).toString(), 0L, VectorUtil.getDenseVector(row.getField(1)), 1.0d);
                }
            })
            .mapPartition(new AgnesKernel(distanceThreshold, k, distance, linkage))
            .setParallelism(1);

        // TransferClusterResult numbers clusters with a per-instance counter; running it as a single
        // instance keeps the assigned cluster indices unique.
        DataSet<Row> clusterRows = clusters.flatMap(new TransferClusterResult(idType)).setParallelism(1);
        DataSet<Row> mergeRows = clusters.flatMap(new MergeInfo(idType));

        TableSchema outputSchema = new TableSchema(
            new String[]{getIdCol(), getPredictionCol()},
            new TypeInformation[]{idType, AlinkTypes.LONG});
        TableSchema mergeSchema = new TableSchema(
            new String[]{"NodeId", "MergeIteration", "ParentId"},
            new TypeInformation[]{idType, AlinkTypes.LONG, idType});

        setOutput(clusterRows, outputSchema);
        setSideOutputTables(new Table[]{DataSetConversionUtil.toTable(getMLEnvironmentId(), mergeRows, mergeSchema)});
        return this;
    }

    /**
     * Creates an {@link AgnesModelInfoBatchOp} with this operator's parameters and links it to this
     * operator.
     *
     * @return the linked model info operator
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public AgnesModelInfoBatchOp getModelInfoBatchOp() {
        return new AgnesModelInfoBatchOp(getParams()).linkFrom(this);
    }
}
