package com.alibaba.alink.operator.common.classification;

import com.alibaba.alink.common.exceptions.AkColumnNotFoundException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.classification.NaiveBayesTextPredictParams;
import com.alibaba.alink.params.classification.NaiveBayesTextTrainParams;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/NaiveBayesTextModelMapper.class */
/**
 * Prediction-side mapper for the text Naive Bayes model.
 *
 * <p>For every input row the mapper reads one vector column and computes one unnormalized
 * log-score per label. It predicts the label with the largest score. For the detail column, it
 * converts the scores into per-label probabilities (softmax) serialized as a JSON map.
 *
 * <ul>
 *   <li>Multinomial model: score = {@code theta * x + pi}.</li>
 *   <li>Bernoulli model: score = {@code minMat * x + pi + phi}; every feature must be 0 or 1.</li>
 * </ul>
 *
 * <p>NOTE(review): the formulas above assume {@code NaiveBayesBLASUtil.gemv} follows standard
 * BLAS gemv semantics ({@code y = alpha * A * x + beta * y}) — confirm against that helper.
 */
public class NaiveBayesTextModelMapper extends RichModelMapper {
    private static final long serialVersionUID = -1418316110985603945L;

    /** Field names of the data schema being predicted. */
    public String[] colNames;

    /**
     * Name of the vector column that is scored. Taken from the predict params when configured;
     * otherwise filled from the model's training vector column by {@link #loadModel(List)}.
     */
    public String vectorColName;

    /** Position of {@link #vectorColName} within {@link #colNames}; -1 until resolved. */
    public int vectorIndex;

    /** Model data, populated by {@link #loadModel(List)}. */
    public NaiveBayesTextModelData modelData;

    /**
     * Creates the mapper and, if the predict params name a vector column, resolves its index.
     *
     * @param tableSchema  schema of the model table
     * @param tableSchema2 schema of the data to predict
     * @param params       predict parameters
     * @throws AkColumnNotFoundException if a configured vector column is absent from the data schema
     */
    public NaiveBayesTextModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.vectorColName = null;
        this.vectorIndex = -1;
        this.colNames = tableSchema2.getFieldNames();
        String configuredCol = params.contains(NaiveBayesTextPredictParams.VECTOR_COL)
            ? (String) params.get(NaiveBayesTextPredictParams.VECTOR_COL)
            : null;
        if (configuredCol != null) {
            this.vectorColName = configuredCol;
            this.vectorIndex = TableUtil.findColIndex(this.colNames, this.vectorColName);
            if (this.vectorIndex == -1) {
                throw new AkColumnNotFoundException("the predict vector is not in the predict data schema.");
            }
        }
    }

    /** Returns the label with the highest score for the sample's vector. */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        double[] scores = calculateProb(VectorUtil.getVector(slicedSelectedSample.get(this.vectorIndex)));
        return findMaxProbLabel(scores, this.modelData.labels);
    }

    /** Returns the best label together with a JSON map of per-label probabilities. */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        double[] scores = calculateProb(VectorUtil.getVector(slicedSelectedSample.get(this.vectorIndex)));
        // The label is picked before generateDetail, which overwrites `scores` in place.
        Object label = findMaxProbLabel(scores, this.modelData.labels);
        return new Tuple2<>(label, generateDetail(scores, this.modelData.pi, this.modelData.labels));
    }

    /** Multinomial scores: {@code theta * x + pi}, one entry per label. */
    private double[] multinomialCalculation(Vector vector) {
        DenseVector scores = DenseVector.zeros(this.modelData.theta.numRows());
        if (vector instanceof DenseVector) {
            NaiveBayesBLASUtil.gemv(1.0, this.modelData.theta, (DenseVector) vector, 0.0, scores);
        } else {
            NaiveBayesBLASUtil.gemv(1.0, this.modelData.theta, (SparseVector) vector, 0.0, scores);
        }
        BLAS.axpy(1.0, new DenseVector(this.modelData.pi), scores);
        return scores.getData();
    }

    /**
     * Bernoulli scores: {@code minMat * x + pi + phi}, one entry per label.
     *
     * <p>Only the first {@code numCols} features (the model width) are validated and used. A dense
     * input of a different width is copied into a private array of exactly the model width: longer
     * inputs are truncated, shorter ones are zero-padded (absent features count as 0, as for sparse
     * input). The caller's vector is never modified.
     */
    private double[] bernoulliCalculation(Vector vector) {
        int numRows = this.modelData.theta.numRows();
        int numCols = this.modelData.theta.numCols();
        DenseVector scores = DenseVector.zeros(numRows);
        if (vector instanceof DenseVector) {
            DenseVector dense = (DenseVector) vector;
            double[] data = dense.getData();
            int usable = Math.min(data.length, numCols);
            for (int i = 0; i < usable; i++) {
                checkBinaryFeature(data[i]);
            }
            DenseVector features = dense;
            if (data.length != numCols) {
                // Copy instead of calling setData on the input: the input may be the row's own object.
                double[] fitted = new double[numCols];
                System.arraycopy(data, 0, fitted, 0, usable);
                features = new DenseVector(fitted);
            }
            NaiveBayesBLASUtil.gemv(1.0, this.modelData.minMat, features, 0.0, scores);
        } else {
            SparseVector sparse = (SparseVector) vector;
            int[] indices = sparse.getIndices();
            double[] values = sparse.getValues();
            // Stops at the first index beyond the model width; assumes indices are sorted ascending.
            for (int i = 0; i < indices.length && indices[i] < numCols; i++) {
                checkBinaryFeature(values[i]);
            }
            NaiveBayesBLASUtil.gemv(1.0, this.modelData.minMat, sparse, 0.0, scores);
        }
        BLAS.axpy(1.0, new DenseVector(this.modelData.pi), scores);
        BLAS.axpy(1.0, new DenseVector(this.modelData.phi), scores);
        return scores.getData();
    }

    /** Rejects feature values other than 0 or 1, as required by the Bernoulli model. */
    private static void checkBinaryFeature(double value) {
        AkPreconditions.checkArgument(value == 0.0 || value == 1.0, "Bernoulli naive Bayes requires 0 or 1 feature values.");
    }

    /**
     * Loads the model. If no vector column was configured in the predict params, this falls back to
     * the model's training vector column and resolves its index in the data schema.
     *
     * <p>A column configured in the predict params is kept. The fallback runs only while
     * {@link #vectorColName} is still null.
     *
     * @throws AkColumnNotFoundException if the fallback column is missing from the data schema
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        this.modelData = new NaiveBayesTextModelDataConverter().load(list);
        if (this.vectorColName == null) {
            String modelCol = this.modelData.vectorColName;
            this.vectorIndex = modelCol == null ? -1 : TableUtil.findColIndex(this.colNames, modelCol);
            if (this.vectorIndex == -1) {
                throw new AkColumnNotFoundException(
                    "the model vector column '" + modelCol + "' is not in the predict data schema.");
            }
            this.vectorColName = modelCol;
        }
    }

    /**
     * Converts per-label log-scores into probabilities (softmax via log-sum-exp) and serializes
     * them as a JSON map {@code label.toString() -> probability}.
     *
     * <p>Side effect: {@code dArr} is overwritten in place with the probabilities.
     *
     * @param dArr   per-label log-scores; overwritten with probabilities
     * @param dArr2  array whose length gives the number of labels to emit
     * @param objArr labels, aligned with {@code dArr}
     * @return JSON string of the label-to-probability map
     */
    public static String generateDetail(double[] dArr, double[] dArr2, Object[] objArr) {
        double max = dArr[0];
        for (int i = 1; i < dArr.length; i++) {
            if (max < dArr[i]) {
                max = dArr[i];
            }
        }
        // Subtract the max before exponentiating so the sum cannot overflow.
        double sumExp = 0.0;
        for (double score : dArr) {
            sumExp += Math.exp(score - max);
        }
        double logNorm = max + Math.log(sumExp);
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = Math.exp(dArr[i] - logNorm);
        }
        int length = dArr2.length;
        HashMap<String, Double> probabilities = new HashMap<>(length);
        for (int i = 0; i < length; i++) {
            probabilities.put(objArr[i].toString(), dArr[i]);
        }
        return JsonConverter.toJson(probabilities);
    }

    /**
     * Returns the label whose score is strictly largest; the first one wins on ties.
     *
     * @return the best label, or {@code null} if {@code dArr} is empty or contains only NaN
     */
    public static Object findMaxProbLabel(double[] dArr, Object[] objArr) {
        Object best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < dArr.length; i++) {
            if (bestScore < dArr[i]) {
                bestScore = dArr[i];
                best = objArr[i];
            }
        }
        return best;
    }

    /**
     * Computes per-label log-scores for {@code vector}: the multinomial formula when the model type
     * is Multinomial, otherwise the Bernoulli formula.
     */
    public double[] calculateProb(Vector vector) {
        return NaiveBayesTextTrainParams.ModelType.Multinomial.equals(this.modelData.modelType)
            ? multinomialCalculation(vector)
            : bernoulliCalculation(vector);
    }
}
