package com.alibaba.alink.operator.stream.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.io.directreader.DataBridge;
import com.alibaba.alink.common.io.directreader.DirectReader;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansUtil;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil;
import com.alibaba.alink.params.clustering.StreamingKMeansParams;
import com.alibaba.alink.params.shared.colname.HasPredictionCol;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Streaming k-means clustering.
 *
 * <p>Starts from a KMeans model trained in batch (the constructor argument) and keeps refining it from
 * the first linked input stream. The pipeline has three stages:
 * <ol>
 *   <li>{@code local_aggregate}: every subtask sums the points assigned to each centroid during a
 *       window of {@code timeInterval} seconds ({@link CollectUpdateData}).</li>
 *   <li>{@code global_aggregate}: the per-subtask sums of the same window are merged
 *       ({@link AllDataMerge}, parallelism 1).</li>
 *   <li>{@code update_model}: each centroid is moved towards the mean of its new points. Old evidence
 *       is discounted by {@code 0.5 ^ (timeInterval / halfLife)} per window ({@link UpdateModelOp},
 *       parallelism 1).</li>
 * </ol>
 * Every updated model is broadcast to {@link PredictOp}. That operator assigns the rows of the second
 * input stream to the closest centroid, or the rows of the first stream when only one input is linked.
 * The model stream is also exposed as a side output table.
 */
@InputPorts(values = {@PortSpec(value = PortType.MODEL, opType = PortSpec.OpType.BATCH), @PortSpec(PortType.DATA), @PortSpec(value = PortType.DATA, isOptional = true)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT), @PortSpec(value = PortType.MODEL, desc = PortDesc.KMEANS_MODEL)})
@NameCn("流式K均值聚类")
@ReservedColsWithFirstInputSpec
@NameEn("Streaming Kmeans")
public final class StreamingKMeansStreamOp extends StreamOperator<StreamingKMeansStreamOp> implements StreamingKMeansParams<StreamingKMeansStreamOp> {
    private static final long serialVersionUID = -7631814863449716946L;

    /** Batch operator that produces the initial KMeans model. */
    private BatchOperator<?> batchModel;

    /**
     * Merges the per-subtask partial sums that carry the same window id.
     *
     * <p>Once all {@code taskNum} upstream subtasks have reported a window, the merged per-cluster sums
     * and counts are emitted and the window is removed from the buffer. With a single upstream subtask,
     * partial results are forwarded directly. This operator is expected to run with parallelism 1.
     *
     * <p>NOTE(review): the buffer of incomplete windows is not checkpointed.
     */
    public static class AllDataMerge implements FlatMapFunction<Tuple3<DenseVector[], int[], Long>, Tuple2<DenseVector[], int[]>> {
        private static final long serialVersionUID = 6157375461626497956L;
        private static final Logger LOG = LoggerFactory.getLogger(AllDataMerge.class);

        /** window id -> (number of subtasks merged so far, summed vectors, summed counts). */
        private HashMap<Long, Tuple3<Long, DenseVector[], int[]>> map = new HashMap<>();

        /** Number of upstream subtasks that must report a window before it is complete. */
        private int taskNum;

        AllDataMerge(int taskNum) {
            this.taskNum = taskNum;
        }

        @Override
        public void flatMap(Tuple3<DenseVector[], int[], Long> partial, Collector<Tuple2<DenseVector[], int[]>> collector) {
            Long windowId = partial.f2;
            Tuple3<Long, DenseVector[], int[]> merged = map.get(windowId);
            if (merged != null) {
                merged.f0 = merged.f0 + 1;
                for (int i = 0; i < partial.f0.length; i++) {
                    merged.f1[i].plusEqual(partial.f0[i]);
                    merged.f2[i] += partial.f1[i];
                }
                if (merged.f0 == taskNum) {
                    collector.collect(Tuple2.of(merged.f1, merged.f2));
                    map.remove(windowId);
                }
            } else if (taskNum == 1) {
                // Nothing to merge with: the single subtask's result is already complete.
                collector.collect(Tuple2.of(partial.f0, partial.f1));
            } else {
                map.put(windowId, Tuple3.of(1L, partial.f0, partial.f1));
            }
            LOG.info("MapHashSet: {}", map.keySet());
        }
    }

    /**
     * Per-subtask pre-aggregation of the training stream.
     *
     * <p>Each row is assigned to its closest centroid, and per-cluster vector sums and counts are
     * accumulated. Window {@code n} ends {@code n * timeInterval} seconds after {@link #open}. The first
     * row arriving after that point emits the partial sums with window id {@code n}, starts window
     * {@code n + 1}, and is then accumulated into the new window.
     *
     * <p>NOTE(review): windows are closed only when a row arrives, so a subtask that receives no data
     * never emits a window and {@link AllDataMerge} keeps waiting for it. Also, assignment always uses
     * the initial model; updated centroids are not fed back to this operator.
     */
    public static class CollectUpdateData extends RichFlatMapFunction<Row, Tuple3<DenseVector[], int[], Long>> {
        private static final long serialVersionUID = -4082462361419657027L;
        private static final Logger LOG = LoggerFactory.getLogger(CollectUpdateData.class);
        private DataBridge modelDataBridge;
        private String[] trainColNames;
        private long startTime;
        /** Window length in seconds. */
        private long timeInterval;
        /** Id of the currently open window (1-based). */
        private long windowHashCode;
        private transient DenseVector[] sum;
        private transient int[] clusterCount;
        private transient KMeansTrainModelData modelData;
        private transient int[] colIdx;
        private transient ContinuousDistance distance;

        CollectUpdateData(DataBridge dataBridge, String[] trainColNames, long timeInterval) {
            this.modelDataBridge = dataBridge;
            this.trainColNames = trainColNames;
            this.timeInterval = timeInterval;
        }

        @Override
        public void open(Configuration configuration) {
            modelData = initModel(modelDataBridge);
            startTime = System.currentTimeMillis();
            windowHashCode = 1L;
            resetAccumulators();
            distance = modelData.params.distanceType.getFastDistance();
            colIdx = KMeansUtil.getKmeansPredictColIdxs(modelData.params, trainColNames);
        }

        @Override
        public void flatMap(Row row, Collector<Tuple3<DenseVector[], int[], Long>> collector) throws Exception {
            if (System.currentTimeMillis() - startTime > timeInterval * windowHashCode * 1000L) {
                LOG.info("TaskId: {}, TriggerHashCode: {}", getRuntimeContext().getIndexOfThisSubtask(), windowHashCode);
                collector.collect(Tuple3.of(sum, clusterCount, windowHashCode));
                windowHashCode++;
                resetAccumulators();
            }
            // The row that closes a window is counted in the next one instead of being discarded.
            Vector vector = KMeansUtil.getKMeansPredictVector(colIdx, row);
            int clusterIndex = KMeansUtil.getClosestClusterIndex(modelData, vector, distance).f0;
            clusterCount[clusterIndex]++;
            sum[clusterIndex].plusEqual(vector);
        }

        /** Replaces the per-cluster sums and counts with fresh zeroed buffers for a new window. */
        private void resetAccumulators() {
            int k = modelData.params.k;
            sum = new DenseVector[k];
            for (int i = 0; i < k; i++) {
                sum[i] = DenseVector.zeros(modelData.params.vectorSize);
            }
            clusterCount = new int[k];
        }
    }

    /** Which optional prediction columns are produced in addition to the cluster id. */
    public enum PredType {
        PRED,
        PRED_CLUS,
        PRED_DIST,
        PRED_CLUS_DIST;

        /**
         * Derives the prediction type from the configured optional output columns.
         *
         * @param params operator parameters
         * @return {@code PRED_CLUS_DIST} if both the cluster and distance columns are set,
         *     {@code PRED_CLUS} or {@code PRED_DIST} if only one of them is set, else {@code PRED}
         */
        static PredType fromInputs(Params params) {
            boolean hasClusterCol = params.contains(StreamingKMeansParams.PREDICTION_CLUSTER_COL);
            boolean hasDistanceCol = params.contains(StreamingKMeansParams.PREDICTION_DISTANCE_COL);
            if (hasClusterCol && hasDistanceCol) {
                return PRED_CLUS_DIST;
            }
            if (hasClusterCol) {
                return PRED_CLUS;
            }
            if (hasDistanceCol) {
                return PRED_DIST;
            }
            return PRED;
        }
    }

    /**
     * Assigns every row of the first input to the closest centroid of the latest model.
     *
     * <p>The model starts as the batch model and is replaced by each model received on the second
     * (broadcast) input. The output row is the input row plus the cluster id and, depending on
     * {@link PredType}, the centroid vector and/or the distance to it.
     */
    public static class PredictOp extends RichCoFlatMapFunction<Row, KMeansTrainModelData, Row> {
        private static final long serialVersionUID = 1824350851125007591L;
        private DataBridge modelDataBridge;
        private String[] trainColNames;
        private OutputColsHelper outputColsHelper;
        private transient ContinuousDistance distance;
        private transient KMeansTrainModelData modelData;
        private transient int[] colIdx;
        private PredType predType;

        PredictOp(DataBridge dataBridge, String[] colNames, OutputColsHelper outputColsHelper, PredType predType) {
            this.outputColsHelper = outputColsHelper;
            this.trainColNames = colNames;
            this.modelDataBridge = dataBridge;
            this.predType = predType;
        }

        @Override
        public void open(Configuration configuration) {
            modelData = initModel(modelDataBridge);
            distance = modelData.params.distanceType.getFastDistance();
            colIdx = KMeansUtil.getKmeansPredictColIdxs(modelData.params, trainColNames);
        }

        @Override
        public void flatMap1(Row row, Collector<Row> collector) {
            Tuple2<Integer, Double> closest = KMeansUtil.getClosestClusterIndex(
                modelData, KMeansUtil.getKMeansPredictVector(colIdx, row), distance);
            int clusterIndex = closest.f0;
            long clusterId = modelData.getClusterId(clusterIndex);
            DenseVector clusterVector = modelData.getClusterVector(clusterIndex);
            Row predResult;
            switch (predType) {
                case PRED_CLUS:
                    predResult = Row.of(clusterId, clusterVector);
                    break;
                case PRED_DIST:
                    predResult = Row.of(clusterId, closest.f1);
                    break;
                case PRED_CLUS_DIST:
                    predResult = Row.of(clusterId, clusterVector, closest.f1);
                    break;
                default:
                    predResult = Row.of(clusterId);
                    break;
            }
            collector.collect(outputColsHelper.getResultRow(row, predResult));
        }

        @Override
        public void flatMap2(KMeansTrainModelData kMeansTrainModelData, Collector<Row> collector) {
            modelData = kMeansTrainModelData;
        }
    }

    /**
     * Holds the evolving model and applies one update per merged window. This operator must run with
     * parallelism 1.
     *
     * <p>For every cluster, with previous weight {@code w}, centroid {@code c}, and {@code n} new points
     * summing to {@code s}:
     * <pre>
     *   w' = w * decayFactor + n
     *   c' = c * (w * decayFactor) / w' + s / w'     (only when n > 0)
     * </pre>
     * The updated model is emitted downstream. It is checkpointed in operator list state and restored
     * from it on recovery.
     */
    public static class UpdateModelOp extends RichMapFunction<Tuple2<DenseVector[], int[]>, KMeansTrainModelData> implements CheckpointedFunction {
        private static final long serialVersionUID = 4161086998242051550L;
        private static final Logger LOG = LoggerFactory.getLogger(UpdateModelOp.class);
        private transient KMeansTrainModelData modelData;
        private DataBridge modelDataBridge;
        private double decayFactor;
        /** Serialized model: (model params as JSON, model rows). */
        private transient ListState<Tuple2<String, List<String>>> modelState;

        UpdateModelOp(DataBridge dataBridge, double decayFactor) {
            this.modelDataBridge = dataBridge;
            this.decayFactor = decayFactor;
        }

        @Override
        public void open(Configuration configuration) {
            if (getRuntimeContext().getNumberOfParallelSubtasks() > 1) {
                throw new AkIllegalStateException("The parallelism of UpdateModelOp should be one.");
            }
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            LOG.info("StreamingKMeans: initializeState");
            modelState = context.getOperatorStateStore().getListState(new ListStateDescriptor<>(
                "StreamingKMeansModelState",
                TypeInformation.of(new TypeHint<Tuple2<String, List<String>>>() {})));
            if (context.isRestored()) {
                for (Tuple2<String, List<String>> saved : modelState.get()) {
                    LOG.info("Loading state ...");
                    modelData = KMeansUtil.loadModelForTrain(Params.fromJson(saved.f0), saved.f1);
                    LOG.info("Loading state ... OK");
                }
            }
            if (modelData == null) {
                // Fresh start, or a restored snapshot without model state: begin from the batch model.
                modelData = initModel(modelDataBridge);
            }
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            LOG.info("StreamingKMeans: snapshotState at checkpoint");
            Tuple2<Params, Iterable<String>> serialized = new KMeansModelDataConverter().serializeModel(modelData);
            List<String> modelRows = new ArrayList<>();
            serialized.f1.forEach(modelRows::add);
            modelState.clear();
            modelState.add(Tuple2.of(serialized.f0.toJson(), modelRows));
        }

        @Override
        public KMeansTrainModelData map(Tuple2<DenseVector[], int[]> merged) throws Exception {
            updateModel(merged.f0, merged.f1, decayFactor);
            return modelData;
        }

        /**
         * Applies the decayed mini-batch update to every centroid in place.
         *
         * @param sum per-cluster sum of the new points of this window
         * @param clusterCount per-cluster number of new points of this window
         * @param decayFactor multiplier applied to the existing cluster weights
         */
        void updateModel(DenseVector[] sum, int[] clusterCount, double decayFactor) {
            for (int i = 0; i < modelData.centroids.size(); i++) {
                double decayedWeight = modelData.getClusterWeight(i) * decayFactor;
                double updatedWeight = decayedWeight + clusterCount[i];
                // Persist the decayed weight so history keeps fading every window (half-life semantics).
                modelData.setClusterWeight(i, updatedWeight);
                if (clusterCount[i] > 0) {
                    double[] centroid = modelData.getClusterVector(i).getData();
                    double historyShare = decayedWeight / updatedWeight;
                    for (int j = 0; j < sum[i].size(); j++) {
                        centroid[j] = centroid[j] * historyShare + sum[i].get(j) / updatedWeight;
                    }
                }
            }
        }
    }

    /**
     * Converts a model snapshot into model-stream rows.
     *
     * <p>Each row produced by {@link KMeansModelDataConverter#save} is prefixed with a timestamp shared
     * by the whole snapshot and with the number of rows in the snapshot.
     */
    public static class pipeModel implements FlatMapFunction<KMeansTrainModelData, Row> {
        private static final long serialVersionUID = -6252541197996341634L;

        @Override
        public void flatMap(KMeansTrainModelData kMeansTrainModelData, Collector<Row> collector) throws Exception {
            RowCollector rowCollector = new RowCollector();
            new KMeansModelDataConverter().save(kMeansTrainModelData, rowCollector);
            List<Row> rows = rowCollector.getRows();
            Timestamp timestamp = new Timestamp(System.currentTimeMillis());
            long rowCount = rows.size();
            for (Row row : rows) {
                int arity = row.getArity();
                Row out = new Row(arity + 2);
                out.setField(0, timestamp);
                out.setField(1, rowCount);
                for (int i = 0; i < arity; i++) {
                    out.setField(i + 2, row.getField(i));
                }
                collector.collect(out);
            }
        }
    }

    /**
     * Creates the operator with default parameters.
     *
     * @param batchOperator batch operator producing the initial KMeans model
     */
    public StreamingKMeansStreamOp(BatchOperator batchOperator) {
        this(batchOperator, new Params());
    }

    /**
     * Creates the operator.
     *
     * @param batchOperator batch operator producing the initial KMeans model
     * @param params operator parameters
     */
    public StreamingKMeansStreamOp(BatchOperator batchOperator, Params params) {
        super(params);
        this.batchModel = batchOperator;
    }

    /**
     * Reads the model referenced by {@code dataBridge} and converts it into trainable model data.
     *
     * @param dataBridge handle to the collected batch model
     * @return the model in training representation
     */
    public static KMeansTrainModelData initModel(DataBridge dataBridge) {
        return KMeansUtil.transformPredictDataToTrainData(new KMeansModelDataConverter().load(DirectReader.directRead(dataBridge)));
    }

    /**
     * Links the training stream ({@code inputs[0]}) and the optional prediction stream
     * ({@code inputs[1]}; defaults to the training stream).
     */
    @Override
    public StreamingKMeansStreamOp linkFrom(StreamOperator<?>... inputs) {
        checkMinOpSize(1, inputs);
        StreamOperator<?> trainInput = inputs[0];
        StreamOperator<?> predictInput = inputs.length > 1 ? inputs[1] : inputs[0];
        if (!getParams().contains(HasPredictionCol.PREDICTION_COL)) {
            setPredictionCol("cluster_id");
        }
        long timeInterval = getParams().get(TIME_INTERVAL);
        int halfLife = getParams().get(HALF_LIFE);
        // Floating-point division: with integer division any timeInterval < halfLife produced an
        // exponent of 0, i.e. a decay factor of 1 (no forgetting at all).
        double decayFactor = Math.pow(0.5, (double) timeInterval / halfLife);

        PredType predType = PredType.fromInputs(getParams());
        OutputColsHelper outputColsHelper = createOutputColsHelper(predictInput.getSchema(), predType);

        DataBridge modelDataBridge = DirectReader.collect(batchModel);
        SingleOutputStreamOperator<Tuple3<DenseVector[], int[], Long>> localAggregate = trainInput.getDataStream()
            .flatMap(new CollectUpdateData(modelDataBridge, trainInput.getColNames(), timeInterval))
            .name("local_aggregate");
        SingleOutputStreamOperator<KMeansTrainModelData> modelStream = localAggregate
            .flatMap(new AllDataMerge(localAggregate.getParallelism()))
            .name("global_aggregate")
            .setParallelism(1)
            .map(new UpdateModelOp(modelDataBridge, decayFactor))
            .name("update_model")
            .setParallelism(1);
        DataStream<Row> prediction = predictInput.getDataStream()
            .connect(modelStream.broadcast())
            .flatMap(new PredictOp(modelDataBridge, predictInput.getColNames(), outputColsHelper, predType))
            .name("kmeans_prediction");
        setOutput(prediction, outputColsHelper.getResultSchema());
        setSideOutputTables(outputModel(modelStream, getMLEnvironmentId()));
        return this;
    }

    /** Builds the helper that appends the prediction columns selected by {@code predType}. */
    private OutputColsHelper createOutputColsHelper(TableSchema dataSchema, PredType predType) {
        String[] resultCols;
        TypeInformation<?>[] resultTypes;
        switch (predType) {
            case PRED_CLUS:
                resultCols = new String[] {getPredictionCol(), getPredictionClusterCol()};
                resultTypes = new TypeInformation<?>[] {Types.LONG, AlinkTypes.DENSE_VECTOR};
                break;
            case PRED_DIST:
                resultCols = new String[] {getPredictionCol(), getPredictionDistanceCol()};
                resultTypes = new TypeInformation<?>[] {Types.LONG, Types.DOUBLE};
                break;
            case PRED_CLUS_DIST:
                resultCols = new String[] {getPredictionCol(), getPredictionClusterCol(), getPredictionDistanceCol()};
                resultTypes = new TypeInformation<?>[] {Types.LONG, AlinkTypes.DENSE_VECTOR, Types.DOUBLE};
                break;
            default:
                resultCols = new String[] {getPredictionCol()};
                resultTypes = new TypeInformation<?>[] {Types.LONG};
                break;
        }
        return new OutputColsHelper(dataSchema, resultCols, resultTypes, getReservedCols());
    }

    /**
     * Turns the stream of updated models into a model-stream table.
     *
     * <p>The table has the model-stream timestamp and count columns followed by the columns of the
     * KMeans model schema.
     *
     * @param dataStream stream of updated models
     * @param j ML environment id used to create the table
     * @return a single-element array holding the model-stream table
     */
    public static Table[] outputModel(DataStream<KMeansTrainModelData> dataStream, long j) {
        DataStream<Row> modelRows = dataStream.flatMap(new pipeModel());
        TableSchema modelSchema = new KMeansModelDataConverter().getModelSchema();
        String[] modelColNames = modelSchema.getFieldNames();
        TypeInformation<?>[] modelColTypes = modelSchema.getFieldTypes();
        int modelWidth = modelColTypes.length;
        String[] colNames = new String[modelWidth + 2];
        TypeInformation<?>[] colTypes = new TypeInformation<?>[modelWidth + 2];
        colNames[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_NAME;
        colNames[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_NAME;
        colTypes[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_TYPE;
        colTypes[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_TYPE;
        System.arraycopy(modelColNames, 0, colNames, 2, modelWidth);
        System.arraycopy(modelColTypes, 0, colTypes, 2, modelWidth);
        return new Table[] {DataStreamConversionUtil.toTable(Long.valueOf(j), modelRows, new TableSchema(colNames, colTypes))};
    }
}
