package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.Functional;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.GmmModelInfoBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.clustering.GmmClusterSummary;
import com.alibaba.alink.operator.common.clustering.GmmModelData;
import com.alibaba.alink.operator.common.clustering.GmmModelDataConverter;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansInitCentroids;
import com.alibaba.alink.operator.common.dataproc.FirstReducer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.operator.common.statistics.basicstatistic.MultivariateGaussian;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.clustering.GmmTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "vectorCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.VECTOR_TYPES})
@NameCn("高斯混合模型训练")
@NameEn("GMM Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.GaussianMixture")
/* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/GmmTrainBatchOp.class */
public final class GmmTrainBatchOp extends BatchOperator<GmmTrainBatchOp> implements GmmTrainParams<GmmTrainBatchOp>, WithModelInfoBatchOp<GmmModelInfoBatchOp.GmmModelInfo, GmmTrainBatchOp, GmmModelInfoBatchOp> {
    // Class-wide SLF4J logger; used by the convergence check in linkFrom to report per-superstep log-likelihoods.
    private static final Logger LOG = LoggerFactory.getLogger(GmmTrainBatchOp.class);
    // Explicit serialization version id for this operator class.
    private static final long serialVersionUID = 989850858114954550L;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/GmmTrainBatchOp$IterationStatus.class */
    /**
     * Pair of log-likelihood values that travels with every cluster summary through the
     * training iteration. The convergence check compares the two values.
     */
    public static class IterationStatus implements Serializable {
        private static final long serialVersionUID = 6974710485320384520L;

        /** Log-likelihood carried over from the previous pass. */
        double prevLogLikelihood;

        /** Log-likelihood computed in the current pass. */
        double currLogLikelihood;

        /** Instances are created only from within the enclosing operator. */
        private IterationStatus() {
        }

        /**
         * Renders both values as {@code prev:<value>,curr:<value>} using {@code %f} formatting.
         */
        @Override
        public String toString() {
            return String.format("prev:%f,curr:%f", prevLogLikelihood, currLogLikelihood);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/GmmTrainBatchOp$LocalAggregator.class */
    /**
     * Accumulator for one E-M pass over a slice of the training vectors.
     *
     * <p>{@link #add(Vector)} performs the E-step for a single sample against the previous
     * model and folds the resulting responsibilities into running sums. {@link #merge} combines
     * two accumulators so that the sums describe the union of their samples.
     *
     * <p>The {@code transient} members (previous model and scratch buffer) are only needed by
     * {@link #add(Vector)}; {@link #merge} touches none of them.
     */
    public static class LocalAggregator implements Serializable {
        private static final long serialVersionUID = -2375847154917039731L;

        /** Number of mixture components. */
        int k;
        /** Dimension of the input vectors. */
        int featureSize;
        /** Log-likelihood of the model this pass started from. */
        double prevLogLikelihood;

        /** Component weights of the previous model. */
        transient DenseVector oldWeights;
        /** Component means of the previous model. */
        transient DenseVector[] oldMeans;
        /** Component covariances of the previous model (stored, not read by this class). */
        transient DenseVector[] oldCovs;

        /** Per component: sum of responsibilities over all added samples. */
        DenseVector updatedWeightsSum;
        /** Per component: responsibility-weighted sum of the samples. */
        DenseVector[] updatedMeansSum;
        /**
         * Per component: responsibility-weighted sum of x[r] * x[c] for c &lt;= r, packed row by
         * row into a vector of length featureSize * (featureSize + 1) / 2.
         */
        DenseVector[] updatedCovsSum;

        /** Gaussian densities of the previous model, one per component. */
        transient MultivariateGaussian[] mnd;
        /** Scratch buffer holding the responsibilities of the sample being added. */
        transient double[] prob;

        /** Number of samples folded into the sums. */
        long totalCount = 0;
        /**
         * Sum of log(mixture density) over all samples folded into this aggregator.
         * NOTE(review): the initial value is whatever Criteria.INVALID_GAIN is; the accumulation
         * below only makes sense if that constant is 0.0 - confirm.
         */
        double newLogLikelihood = Criteria.INVALID_GAIN;

        /**
         * Creates an empty aggregator bound to the previous model.
         *
         * @param numComponents     number of mixture components
         * @param dim               dimension of the input vectors
         * @param prevLogLikelihood log-likelihood of the previous model
         * @param weights           component weights of the previous model
         * @param means             component means of the previous model
         * @param covs              component covariances of the previous model
         * @param gaussians         density objects of the previous model, indexed by component
         */
        LocalAggregator(int numComponents, int dim, double prevLogLikelihood, DenseVector weights, DenseVector[] means, DenseVector[] covs, MultivariateGaussian[] gaussians) {
            this.k = numComponents;
            this.featureSize = dim;
            this.oldWeights = weights;
            this.oldMeans = means;
            this.oldCovs = covs;
            this.prevLogLikelihood = prevLogLikelihood;
            this.updatedWeightsSum = new DenseVector(numComponents);
            this.updatedMeansSum = new DenseVector[numComponents];
            this.updatedCovsSum = new DenseVector[numComponents];
            this.mnd = gaussians;
            for (int c = 0; c < numComponents; c++) {
                this.updatedMeansSum[c] = new DenseVector(dim);
                this.updatedCovsSum[c] = new DenseVector(((dim + 1) * dim) / 2);
            }
            this.prob = new double[numComponents];
        }

        /**
         * Folds one sample into the running sums.
         *
         * @param vector the sample; read through {@code get(int)} for indices below featureSize
         */
        public void add(Vector vector) {
            // Unnormalised responsibility of each component: weight * density.
            double mixtureDensity = 0.0d;
            for (int c = 0; c < this.k; c++) {
                double weighted = this.oldWeights.get(c) * this.mnd[c].pdf(vector);
                this.prob[c] = weighted;
                mixtureDensity += weighted;
            }
            // Normalise so the responsibilities of this sample sum to one.
            for (int c = 0; c < this.k; c++) {
                this.prob[c] = this.prob[c] / mixtureDensity;
            }
            this.newLogLikelihood += Math.log(mixtureDensity);

            for (int c = 0; c < this.k; c++) {
                this.updatedWeightsSum.add(c, this.prob[c]);
                this.updatedMeansSum[c].plusScaleEqual(vector, this.prob[c]);
                DenseVector covSum = this.updatedCovsSum[c];
                // Walk the packed triangle: row r contributes columns 0..r.
                int pos = 0;
                for (int r = 0; r < this.featureSize; r++) {
                    for (int col = 0; col <= r; col++) {
                        covSum.add(pos, vector.get(r) * vector.get(col) * this.prob[c]);
                        pos++;
                    }
                }
            }
            this.totalCount++;
        }

        /**
         * Adds the sums of {@code localAggregator} into this aggregator.
         *
         * <p>The log-likelihood is accumulated too: it is a sum over samples just like the other
         * statistics, so leaving it out would make the merged value cover only the samples that
         * were added to this instance.
         *
         * @param localAggregator the aggregator to absorb; it is not modified
         * @return this aggregator, now describing the samples of both
         */
        public LocalAggregator merge(LocalAggregator localAggregator) {
            this.totalCount += localAggregator.totalCount;
            this.newLogLikelihood += localAggregator.newLogLikelihood;
            this.updatedWeightsSum.plusEqual(localAggregator.updatedWeightsSum);
            for (int c = 0; c < this.k; c++) {
                this.updatedMeansSum[c].plusEqual(localAggregator.updatedMeansSum[c]);
                this.updatedCovsSum[c].plusEqual(localAggregator.updatedCovsSum[c]);
            }
            return this;
        }
    }

    /**
     * Creates the operator with a fresh, empty parameter set.
     */
    public GmmTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the supplied parameters.
     *
     * @param params parameters handed unchanged to the {@code BatchOperator} constructor
     */
    public GmmTrainBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Creates a model-info operator configured with this operator's parameters and links it
     * to this operator's output.
     *
     * @return the linked {@link GmmModelInfoBatchOp}
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public GmmModelInfoBatchOp getModelInfoBatchOp() {
        GmmModelInfoBatchOp modelInfoOp = new GmmModelInfoBatchOp(getParams());
        return modelInfoOp.linkFrom(this);
    }

    /**
     * Builds the initial mixture model from vectors selected out of the input.
     *
     * <p>{@code numClusters * 5} vectors are requested from
     * {@code KMeansInitCentroids.selectTopK} (presumably a seeded pseudo-random pick - verify
     * against that helper). The selected vectors are grouped by {@code id % numClusters}; each
     * group yields one component whose mean is the group average, whose covariance is diagonal
     * with the per-dimension mean squared deviation, and whose weight is {@code 1 / numClusters}.
     *
     * @param dataSet     training vectors
     * @param numClusters number of mixture components
     * @param seed        value forwarded to {@code selectTopK}
     * @return one (clusterId, summary, fresh iteration status) tuple per group
     */
    private static DataSet<Tuple3<Integer, GmmClusterSummary, IterationStatus>> initRandom(DataSet<Vector> dataSet, final int numClusters, int seed) {
        // Number of selected vectors that are averaged into each initial component.
        final int sampleSize = 5;
        return KMeansInitCentroids.selectTopK(numClusters * sampleSize, seed, dataSet, new Functional.SerializableFunction<Vector, byte[]>() {
            private static final long serialVersionUID = -7473942125005406072L;

            @Override // java.util.function.Function
            public byte[] apply(Vector vector) {
                return vector.toBytes();
            }
        }).groupBy(new KeySelector<Tuple2<Long, Vector>, Integer>() {
            private static final long serialVersionUID = -4882845811146695041L;

            @Override
            public Integer getKey(Tuple2<Long, Vector> sample) {
                return Integer.valueOf((int) (sample.f0.longValue() % numClusters));
            }
        }).reduceGroup(new GroupReduceFunction<Tuple2<Long, Vector>, Tuple3<Integer, GmmClusterSummary, IterationStatus>>() {
            private static final long serialVersionUID = -5884722625126145531L;

            @Override
            public void reduce(Iterable<Tuple2<Long, Vector>> group, Collector<Tuple3<Integer, GmmClusterSummary, IterationStatus>> collector) {
                List<Vector> samples = new ArrayList<>(sampleSize);
                int clusterId = -1;
                int featureSize = 0;
                for (Tuple2<Long, Vector> sample : group) {
                    clusterId = (int) (sample.f0.longValue() % numClusters);
                    featureSize = sample.f1.size();
                    samples.add(sample.f1);
                }
                // Every group is expected to hold exactly sampleSize vectors; checked only with -ea.
                assert sampleSize == samples.size();

                // Mean: average of the group's vectors.
                DenseVector mean = new DenseVector(featureSize);
                for (Vector sample : samples) {
                    mean.plusEqual(sample);
                }
                mean.scaleEqual(1.0d / sampleSize);

                // Per-dimension mean squared deviation from the group mean.
                DenseVector variance = new DenseVector(featureSize);
                for (Vector sample : samples) {
                    Vector diff = sample.minus(mean);
                    for (int d = 0; d < featureSize; d++) {
                        variance.add(d, diff.get(d) * diff.get(d));
                    }
                }
                variance.scaleEqual(1.0d / sampleSize);

                // Compact covariance with only the diagonal filled in.
                DenseVector compactCov = new DenseVector((featureSize * (featureSize + 1)) / 2);
                for (int d = 0; d < featureSize; d++) {
                    compactCov.set(GmmModelData.getElementPositionInCompactMatrix(d, d, featureSize), variance.get(d));
                }

                GmmClusterSummary summary = new GmmClusterSummary(clusterId, 1.0d / numClusters, mean, compactCov);
                collector.collect(Tuple3.of(Integer.valueOf(clusterId), summary, new IterationStatus()));
            }
        }).name("init_model");
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Wires up the GMM training dataflow on the first input operator and sets the resulting
     * model rows as this operator's output.
     *
     * <p>Outline of the plan built here:
     * <ol>
     *   <li>Obtain the vector data set and its summary; the summary's vector size is broadcast
     *       as "featureSize".</li>
     *   <li>Create the initial model with {@code initRandom} and open a bulk iteration limited
     *       to {@code maxIter} supersteps.</li>
     *   <li>In each superstep, every partition accumulates a {@code LocalAggregator} against
     *       the broadcast "oldModel"; the aggregators are merged and turned into new cluster
     *       summaries.</li>
     *   <li>The termination-criterion data set keeps receiving a record during the first
     *       superstep and while the absolute log-likelihood change exceeds {@code epsilon}.</li>
     *   <li>The final summaries are packed into a {@code GmmModelData} and serialized with
     *       {@code GmmModelDataConverter} in a single-parallelism task.</li>
     * </ol>
     *
     * @param batchOperatorArr input operators; only the one returned by checkAndGetFirst is used
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public GmmTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> checkAndGetFirst = checkAndGetFirst(batchOperatorArr);
        // Training parameters: vector column, number of components, iteration cap, tolerance.
        final String vectorCol = getVectorCol();
        final int intValue = getK().intValue();
        int intValue2 = getMaxIter().intValue();
        final double doubleValue = getEpsilon().doubleValue();
        // f0: the vectors read from vectorCol, f1: their summary statistics.
        Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> summaryHelper = StatisticsHelper.summaryHelper(checkAndGetFirst, null, vectorCol);
        // Vector size taken from the summary; broadcast below under the name "featureSize".
        MapOperator map = ((DataSet) summaryHelper.f1).map(new MapFunction<BaseVectorSummary, Integer>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.4
            private static final long serialVersionUID = 8456872852742625845L;

            public Integer map(BaseVectorSummary baseVectorSummary) {
                return Integer.valueOf(baseVectorSummary.vectorSize());
            }
        });
        // Training vectors, with every sparse vector's size set to the broadcast feature size.
        SingleInputUdfOperator withBroadcastSet = ((DataSet) summaryHelper.f0).map(new RichMapFunction<Vector, Vector>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.5
            private static final long serialVersionUID = -845795862675993897L;
            transient int featureSize;

            public void open(Configuration configuration) {
                this.featureSize = ((Integer) getRuntimeContext().getBroadcastVariable("featureSize").get(0)).intValue();
            }

            public Vector map(Vector vector) {
                if (vector instanceof SparseVector) {
                    ((SparseVector) vector).setSize(this.featureSize);
                }
                return vector;
            }
        }).withBroadcastSet(map, "featureSize");
        // Initial model -> bulk iteration with at most maxIter supersteps.
        IterativeDataSet iterate = initRandom(withBroadcastSet, intValue, getRandomSeed().intValue()).iterate(intValue2);
        // One superstep: per-partition accumulation -> merge -> new cluster summaries,
        // finally partitioned on the cluster id (field 0).
        PartitionOperator partitionCustom = withBroadcastSet.mapPartition(new RichMapPartitionFunction<Vector, LocalAggregator>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.10
            private static final long serialVersionUID = 8356493076036649604L;
            transient DenseVector oldWeights;
            transient DenseVector[] oldMeans;
            transient DenseVector[] oldCovs;
            transient MultivariateGaussian[] mnd;

            // Allocates the per-task holders for the previous model, sized by k.
            public void open(Configuration configuration) {
                this.oldWeights = new DenseVector(intValue);
                this.oldMeans = new DenseVector[intValue];
                this.oldCovs = new DenseVector[intValue];
                this.mnd = new MultivariateGaussian[intValue];
            }

            // Loads the broadcast "oldModel" into the holders, then feeds every vector of the
            // partition into a fresh LocalAggregator and emits that single aggregator.
            public void mapPartition(Iterable<Vector> iterable, Collector<LocalAggregator> collector) {
                List broadcastVariable = getRuntimeContext().getBroadcastVariable("featureSize");
                double d = 0.0d;
                for (Tuple4 tuple4 : getRuntimeContext().getBroadcastVariable("oldModel")) {
                    int intValue3 = ((Integer) tuple4.f0).intValue();
                    GmmClusterSummary gmmClusterSummary = (GmmClusterSummary) tuple4.f1;
                    // The previous pass's "current" value becomes this pass's "previous" value.
                    d = ((IterationStatus) tuple4.f2).currLogLikelihood;
                    this.oldWeights.set(intValue3, gmmClusterSummary.weight);
                    this.oldMeans[intValue3] = gmmClusterSummary.mean;
                    this.oldCovs[intValue3] = gmmClusterSummary.cov;
                    this.mnd[intValue3] = new MultivariateGaussian((MultivariateGaussian) tuple4.f3);
                }
                LocalAggregator localAggregator = new LocalAggregator(intValue, ((Integer) broadcastVariable.get(0)).intValue(), d, this.oldWeights, this.oldMeans, this.oldCovs, this.mnd);
                // Decompiler residue of the null check that accompanies the method reference below.
                localAggregator.getClass();
                iterable.forEach(localAggregator::add);
                collector.collect(localAggregator);
            }
        // "oldModel": each model tuple of the iteration extended with a MultivariateGaussian built
        // from the cluster mean and the expanded covariance matrix.
        }).withBroadcastSet(map, "featureSize").withBroadcastSet(iterate.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, GmmClusterSummary, IterationStatus>, Tuple4<Integer, GmmClusterSummary, IterationStatus, MultivariateGaussian>>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.6
            private static final long serialVersionUID = -1937088240477952410L;

            public void mapPartition(Iterable<Tuple3<Integer, GmmClusterSummary, IterationStatus>> iterable, Collector<Tuple4<Integer, GmmClusterSummary, IterationStatus, MultivariateGaussian>> collector) {
                for (Tuple3<Integer, GmmClusterSummary, IterationStatus> tuple3 : iterable) {
                    DenseVector denseVector = ((GmmClusterSummary) tuple3.f1).mean;
                    collector.collect(Tuple4.of(tuple3.f0, tuple3.f1, tuple3.f2, new MultivariateGaussian(denseVector, GmmModelData.expandCovarianceMatrix(((GmmClusterSummary) tuple3.f1).cov, denseVector.size()))));
                }
            }
        // Pairwise merge of the per-partition aggregators via LocalAggregator.merge.
        }).withForwardedFields(new String[]{"f0;f1;f2"}), "oldModel").name("E-M_step").reduce(new ReduceFunction<LocalAggregator>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.9
            private static final long serialVersionUID = -6976429920344470952L;

            public LocalAggregator reduce(LocalAggregator localAggregator, LocalAggregator localAggregator2) {
                return localAggregator.merge(localAggregator2);
            }
        // Turns the merged sums into one new (clusterId, summary, status) tuple per component.
        }).flatMap(new RichFlatMapFunction<LocalAggregator, Tuple3<Integer, GmmClusterSummary, IterationStatus>>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.8
            private static final long serialVersionUID = 6599047947335456972L;

            public void flatMap(LocalAggregator localAggregator, Collector<Tuple3<Integer, GmmClusterSummary, IterationStatus>> collector) {
                for (int i = 0; i < intValue; i++) {
                    // d: summed responsibilities of component i.
                    double d = localAggregator.updatedWeightsSum.get(i);
                    // Mean and second-moment sums are scaled in place by 1 / d.
                    localAggregator.updatedMeansSum[i].scaleEqual(1.0d / d);
                    localAggregator.updatedCovsSum[i].scaleEqual(1.0d / d);
                    // New weight: d divided by the total sample count.
                    GmmClusterSummary gmmClusterSummary = new GmmClusterSummary(i, d / localAggregator.totalCount, localAggregator.updatedMeansSum[i], localAggregator.updatedCovsSum[i]);
                    int size = gmmClusterSummary.mean.size();
                    // Subtract mean[i2] * mean[i3] from the compact entry for (i2, i3), i3 >= i2.
                    for (int i2 = 0; i2 < size; i2++) {
                        for (int i3 = i2; i3 < size; i3++) {
                            gmmClusterSummary.cov.add(GmmModelData.getElementPositionInCompactMatrix(i2, i3, size), (-1.0d) * gmmClusterSummary.mean.get(i2) * gmmClusterSummary.mean.get(i3));
                        }
                    }
                    IterationStatus iterationStatus = new IterationStatus();
                    iterationStatus.prevLogLikelihood = localAggregator.prevLogLikelihood;
                    iterationStatus.currLogLikelihood = localAggregator.newLogLikelihood;
                    collector.collect(Tuple3.of(Integer.valueOf(i), gmmClusterSummary, iterationStatus));
                }
            }

            // Decompiler-emitted bridge for the erased flatMap signature.
            public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                flatMap((LocalAggregator) obj, (Collector<Tuple3<Integer, GmmClusterSummary, IterationStatus>>) collector);
            }
        // Partition on field 0 (cluster id): target partition is id modulo the partition count.
        }).partitionCustom(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.7
            private static final long serialVersionUID = 1006932050560340472L;

            public int partition(Integer num, int i) {
                return num.intValue() % i;
            }
        }, 0);
        // Close the iteration. The second argument of closeWith is the termination criterion:
        // it is derived from the tuple(s) passed on by FirstReducer(1) and is non-empty only
        // while another superstep is wanted.
        setOutput((DataSet<Row>) iterate.closeWith(partitionCustom, partitionCustom.reduceGroup(new FirstReducer(1)).flatMap(new RichFlatMapFunction<Tuple3<Integer, GmmClusterSummary, IterationStatus>, Boolean>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.11
            private static final long serialVersionUID = 6890280483282243057L;

            public void flatMap(Tuple3<Integer, GmmClusterSummary, IterationStatus> tuple3, Collector<Boolean> collector) {
                IterationStatus iterationStatus = (IterationStatus) tuple3.f2;
                int superstepNumber = getIterationRuntimeContext().getSuperstepNumber();
                double abs = Math.abs(iterationStatus.currLogLikelihood - iterationStatus.prevLogLikelihood);
                GmmTrainBatchOp.LOG.info("step {}, prevLogLikelihood {}, currLogLikelihood {}, diffLogLikelihood {}", new Object[]{Integer.valueOf(superstepNumber), Double.valueOf(iterationStatus.prevLogLikelihood), Double.valueOf(iterationStatus.currLogLikelihood), Double.valueOf(abs)});
                // Emit a record during superstep 1 or while the change is still above epsilon;
                // emitting nothing leaves the criterion data set empty.
                if (superstepNumber <= 1 || abs > doubleValue) {
                    collector.collect(false);
                }
            }

            // Decompiler-emitted bridge for the erased flatMap signature.
            public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                flatMap((Tuple3<Integer, GmmClusterSummary, IterationStatus>) obj, (Collector<Boolean>) collector);
            }
        // Serialize the final cluster summaries into model rows (parallelism forced to 1 below).
        })).mapPartition(new RichMapPartitionFunction<Tuple3<Integer, GmmClusterSummary, IterationStatus>, Row>() { // from class: com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp.12
            private static final long serialVersionUID = -8411238421923712023L;
            transient int featureSize;

            public void open(Configuration configuration) {
                this.featureSize = ((Integer) getRuntimeContext().getBroadcastVariable("featureSize").get(0)).intValue();
            }

            /* JADX WARN: Multi-variable type inference failed */
            public void mapPartition(Iterable<Tuple3<Integer, GmmClusterSummary, IterationStatus>> iterable, Collector<Row> collector) {
                // Guard: the model must be assembled by a single subtask.
                if (getRuntimeContext().getNumberOfParallelSubtasks() > 1) {
                    throw new AkIllegalStateException("parallelism is not 1 when saving model.");
                }
                GmmModelData gmmModelData = new GmmModelData();
                gmmModelData.k = intValue;
                gmmModelData.dim = this.featureSize;
                gmmModelData.vectorCol = vectorCol;
                gmmModelData.data = new ArrayList(intValue);
                for (Tuple3<Integer, GmmClusterSummary, IterationStatus> tuple3 : iterable) {
                    // Overwrite the summary's cluster id with the tuple key before storing it.
                    ((GmmClusterSummary) tuple3.f1).clusterId = ((Integer) tuple3.f0).intValue();
                    gmmModelData.data.add(tuple3.f1);
                }
                new GmmModelDataConverter().save(gmmModelData, collector);
            }
        }).setParallelism(1).withBroadcastSet(map, "featureSize"), new GmmModelDataConverter().getModelSchema());
        return this;
    }

    // Bridge method as emitted by the decompiler (marked bridge/synthetic): accepts the raw
    // BatchOperator[] form and forwards to linkFrom(BatchOperator<?>...).
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ GmmTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
