package com.alibaba.alink.operator.stream.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.io.directreader.DataBridge;
import com.alibaba.alink.common.io.directreader.DirectReader;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansUtil;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.params.clustering.OnePassClusterParams;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.types.Row;

/**
 * Stream operator that clusters records in a single pass, starting from a KMeans model produced by a
 * batch operator.
 *
 * <p>For every incoming row the closest existing centroid is looked up. If its distance is below
 * {@code epsilon}, the row is assigned to that cluster and the centroid is moved to the running mean
 * of its members. Otherwise, a new cluster seeded at the row's vector is created. The output appends
 * the cluster id ({@code LONG}) and the distance to it ({@code DOUBLE}) to the input columns.
 *
 * <p>The clustering function runs with parallelism 1, so a single model instance sees every record.
 * When a model output interval is configured, a snapshot of the current model is published as a side
 * output every {@code modelOutputInterval} records.
 */
@InputPorts(values = {
    @PortSpec(value = PortType.MODEL, desc = PortDesc.KMEANS_MODEL, opType = PortSpec.OpType.BATCH),
    @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_DATA, opType = PortSpec.OpType.STREAM)})
@OutputPorts(values = {
    @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_RESULT, opType = PortSpec.OpType.STREAM),
    @PortSpec(value = PortType.MODEL, desc = PortDesc.KMEANS_MODEL, opType = PortSpec.OpType.STREAM)})
@ReservedColsWithSecondInputSpec
@NameCn("一趟聚类")
@NameEn("One Pass Cluster")
public final class OnePassClusterStreamOp extends StreamOperator<OnePassClusterStreamOp>
    implements OnePassClusterParams<OnePassClusterStreamOp> {
    private static final long serialVersionUID = -9023400083161571185L;

    /** Batch operator that provides the initial KMeans model. */
    BatchOperator<?> batchModel;

    /**
     * Creates the operator with default parameters.
     *
     * @param batchOperator batch operator producing the initial KMeans model
     */
    public OnePassClusterStreamOp(BatchOperator<?> batchOperator) {
        this(batchOperator, new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param batchOperator batch operator producing the initial KMeans model
     * @param params        operator parameters
     */
    public OnePassClusterStreamOp(BatchOperator<?> batchOperator, Params params) {
        super(params);
        this.batchModel = batchOperator;
    }

    /**
     * Stateful map function that assigns each row to a cluster and updates the in-memory model.
     *
     * <p>The model lives only in this function's fields. This class does not register it as Flink
     * managed state. NOTE(review): model updates are therefore presumably lost on failover — confirm
     * whether that is acceptable.
     */
    public static class OnePassCluster extends RichMapFunction<Row, Tuple2<Row, KMeansTrainModelData>> {
        private static final long serialVersionUID = -8013967247690938140L;

        private final OutputColsHelper outputColsHelper;
        private final DataBridge dataBridge;
        private final String[] colNames;
        private final double epsilon;
        /** Emit a model snapshot every this many records; {@code null} disables model output. */
        private final Integer modelOutputInterval;

        private KMeansTrainModelData modelData;
        private int[] colIdx;
        private ContinuousDistance distance;
        /** Records seen since the last model snapshot was emitted. */
        private int cnt = 0;

        OnePassCluster(DataBridge dataBridge, OutputColsHelper outputColsHelper, String[] colNames,
                       double epsilon, Integer modelOutputInterval) {
            this.dataBridge = dataBridge;
            this.outputColsHelper = outputColsHelper;
            this.colNames = colNames;
            this.epsilon = epsilon;
            this.modelOutputInterval = modelOutputInterval;
        }

        /** Loads the initial model from the data bridge and resolves the distance and feature columns. */
        @Override
        public void open(Configuration configuration) {
            this.modelData = StreamingKMeansStreamOp.initModel(this.dataBridge);
            this.distance = this.modelData.params.distanceType.getFastDistance();
            this.colIdx = KMeansUtil.getKmeansPredictColIdxs(this.modelData.params, this.colNames);
        }

        /**
         * Clusters one row and updates the model.
         *
         * @param row input row
         * @return the input row extended with (cluster id, distance), paired with the current model
         *         when a snapshot is due, or with {@code null} otherwise
         */
        @Override
        public Tuple2<Row, KMeansTrainModelData> map(Row row) throws Exception {
            Vector sample = KMeansUtil.getKMeansPredictVector(this.colIdx, row);
            Tuple2<Integer, Double> closest =
                KMeansUtil.getClosestClusterIndex(this.modelData, sample, this.distance);

            Row prediction;
            if (closest.f1 < this.epsilon) {
                int clusterId = closest.f0;
                DenseVector centroid = this.modelData.getClusterVector(clusterId);
                double weight = this.modelData.getClusterWeight(clusterId);
                double newWeight = weight + 1.0;
                this.modelData.setClusterWeight(clusterId, newWeight);
                // Incremental mean: c' = (w * c + x) / (w + 1). Both coefficients use w + 1 so they sum to 1.
                centroid.scaleEqual(weight / newWeight);
                centroid.plusScaleEqual(sample, 1.0 / newWeight);
                prediction = Row.of((long) clusterId, closest.f1);
            } else {
                int clusterId = this.modelData.centroids.size();
                // Copy dense input so later in-place centroid updates cannot mutate the caller's vector,
                // which may also be carried into the output row as a reserved column.
                DenseVector seed = sample instanceof SparseVector
                    ? ((SparseVector) sample).toDenseVector()
                    : ((DenseVector) sample).clone();
                this.modelData.centroids.add(new KMeansTrainModelData.ClusterSummary(seed, clusterId, 1.0));
                this.modelData.params.k = this.modelData.centroids.size();
                // The new centroid is the sample itself, so its distance to it is zero.
                prediction = Row.of((long) clusterId, 0.0);
            }

            Row result = this.outputColsHelper.getResultRow(row, prediction);
            return Tuple2.of(result, shouldEmitModel() ? this.modelData : null);
        }

        /** Advances the record counter and reports whether a model snapshot is due for this record. */
        private boolean shouldEmitModel() {
            if (this.modelOutputInterval == null) {
                return false;
            }
            this.cnt++;
            if (this.cnt != this.modelOutputInterval) {
                return false;
            }
            this.cnt = 0;
            return true;
        }
    }

    /** Extracts the prediction row from the clustering output. */
    private static final class PredictionExtractor
        implements MapFunction<Tuple2<Row, KMeansTrainModelData>, Row> {
        private static final long serialVersionUID = 7997807036805771891L;

        @Override
        public Row map(Tuple2<Row, KMeansTrainModelData> value) {
            return value.f0;
        }
    }

    /** Keeps only clustering outputs that carry a model snapshot. */
    private static final class ModelSnapshotFilter
        implements FilterFunction<Tuple2<Row, KMeansTrainModelData>> {
        private static final long serialVersionUID = 4166084406562047700L;

        @Override
        public boolean filter(Tuple2<Row, KMeansTrainModelData> value) {
            return value.f1 != null;
        }
    }

    /** Extracts the model snapshot from the clustering output. */
    private static final class ModelExtractor
        implements MapFunction<Tuple2<Row, KMeansTrainModelData>, KMeansTrainModelData> {
        private static final long serialVersionUID = -5162839831147241620L;

        @Override
        public KMeansTrainModelData map(Tuple2<Row, KMeansTrainModelData> value) {
            return value.f1;
        }
    }

    /**
     * Wires the single-pass clustering onto the input stream.
     *
     * @param inputs exactly one stream operator with the data to cluster
     * @return this operator, with predictions as main output and model snapshots as side output
     */
    @Override
    public OnePassClusterStreamOp linkFrom(StreamOperator<?>... inputs) {
        checkOpSize(1, inputs);
        StreamOperator<?> in = inputs[0];

        OutputColsHelper outputColsHelper = new OutputColsHelper(
            in.getSchema(),
            new String[] {getPredictionCol(), getPredictionDetailCol()},
            new TypeInformation<?>[] {Types.LONG, Types.DOUBLE},
            getReservedCols());

        // Parallelism 1: the evolving model is a single, global instance.
        SingleOutputStreamOperator<Tuple2<Row, KMeansTrainModelData>> clustered = in.getDataStream()
            .map(new OnePassCluster(DirectReader.collect(this.batchModel), outputColsHelper,
                in.getSchema().getFieldNames(), getEpsilon(), getModelOutputInterval()))
            .setParallelism(1);

        DataStream<Row> predictions = clustered.map(new PredictionExtractor());

        DataStream<KMeansTrainModelData> modelSnapshots = clustered
            .filter(new ModelSnapshotFilter())
            .map(new ModelExtractor());
        setSideOutputTables(StreamingKMeansStreamOp.outputModel(modelSnapshots, getMLEnvironmentId()));

        setOutput(predictions, outputColsHelper.getResultSchema());
        return this;
    }
}
