package com.alibaba.alink.operator.common.clustering;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.statistics.basicstatistic.MultivariateGaussian;
import com.alibaba.alink.params.clustering.GmmPredictParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/GmmModelMapper.class */
/**
 * Assigns each input vector to a component of a trained Gaussian mixture model.
 *
 * <p>The prediction result is the 0-based index (as {@code Long}) of the component with the
 * highest posterior probability. The prediction detail is the full vector of posterior
 * probabilities, serialized with {@link VectorUtil#serialize}.
 */
public class GmmModelMapper extends RichModelMapper {
    private static final long serialVersionUID = -4999832537099548829L;

    /** Index of the vector column within the input data schema. */
    private int vectorColIdx;
    private GmmModelData modelData;
    /** One Gaussian per mixture component, built from the model's means and covariances. */
    private MultivariateGaussian[] multivariateGaussians;
    /** Natural log of each component's mixture weight; filled in by {@link #loadModel}. */
    private double[] logWeights;
    /** Per-thread scratch buffer holding the k posterior probabilities; reused across rows. */
    private transient ThreadLocal<double[]> threadLocalProb;

    /**
     * @param modelSchema schema of the model table
     * @param dataSchema  schema of the input data; must contain the column named by
     *                    {@link GmmPredictParams#VECTOR_COL}
     * @param params      prediction parameters
     */
    public GmmModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.vectorColIdx = TableUtil.findColIndexWithAssertAndHint(
            dataSchema.getFieldNames(), this.params.get(GmmPredictParams.VECTOR_COL));
    }

    /**
     * Returns the index (as {@code Long}) of the most probable mixture component.
     * Unlike {@link #predictResultDetail}, this does not serialize the probability vector.
     */
    @Override
    protected Object predictResult(Mapper.SlicedSelectedSample selection) throws Exception {
        Vector sample = VectorUtil.getVector(selection.get(this.vectorColIdx));
        return (long) computePosterior(sample, this.threadLocalProb.get());
    }

    /**
     * Returns the index of the most probable component together with the serialized
     * posterior probability of every component.
     */
    @Override
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample selection) throws Exception {
        Vector sample = VectorUtil.getVector(selection.get(this.vectorColIdx));
        double[] prob = this.threadLocalProb.get();
        int clusterId = computePosterior(sample, prob);
        // The detail is materialized as a String right away, so the thread-local buffer
        // can safely be overwritten by the next row.
        return Tuple2.of((long) clusterId, VectorUtil.serialize(new DenseVector(prob)));
    }

    /**
     * Fills the first k slots of {@code prob} with the posterior probability of each mixture
     * component for {@code sample}, and returns the index of the most probable component
     * (the first one on ties).
     *
     * <p>The computation is done in log space with log-sum-exp normalization. A sample lying
     * far from every component, whose plain densities would all underflow to 0, therefore
     * still gets finite probabilities instead of NaN from a 0/0 division.
     */
    private int computePosterior(Vector sample, double[] prob) {
        int k = this.modelData.k;
        double maxLog = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < k; i++) {
            double logP = this.logWeights[i] + this.multivariateGaussians[i].logpdf(sample);
            prob[i] = logP;
            if (logP > maxLog) {
                maxLog = logP;
                maxIndex = i;
            }
        }
        if (maxLog == Double.NEGATIVE_INFINITY) {
            // No component assigns any mass to the sample, so the posterior is undefined.
            // Report a uniform distribution rather than a vector of NaNs.
            for (int i = 0; i < k; i++) {
                prob[i] = 1.0 / k;
            }
            return 0;
        }
        double sum = 0.0;
        for (int i = 0; i < k; i++) {
            // Shift by the maximum so the largest term is exp(0) = 1 and cannot overflow.
            prob[i] = Math.exp(prob[i] - maxLog);
            sum += prob[i];
        }
        for (int i = 0; i < k; i++) {
            prob[i] /= sum;
        }
        return maxIndex;
    }

    /** The predicted cluster index is emitted as a {@code LONG} column. */
    @Override
    protected TypeInformation<?> initPredResultColType(TableSchema tableSchema) {
        return Types.LONG;
    }

    /**
     * Decodes the model rows and precomputes, for each component, its Gaussian and the log of
     * its mixture weight. Also sets up the per-thread probability buffer.
     */
    @Override
    public void loadModel(List<Row> modelRows) {
        this.modelData = new GmmModelDataConverter().load(modelRows);
        final int k = this.modelData.k;
        this.multivariateGaussians = new MultivariateGaussian[k];
        this.logWeights = new double[k];
        for (int i = 0; i < k; i++) {
            this.multivariateGaussians[i] = new MultivariateGaussian(
                this.modelData.data.get(i).mean,
                GmmModelData.expandCovarianceMatrix(this.modelData.data.get(i).cov, this.modelData.dim));
            // A zero weight yields -Infinity, which log-sum-exp treats as zero probability.
            this.logWeights[i] = Math.log(this.modelData.data.get(i).weight);
        }
        this.threadLocalProb = ThreadLocal.withInitial(() -> new double[k]);
    }
}
