package com.alibaba.alink.operator.batch.classification;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.classification.NaiveBayesTextModelData;
import com.alibaba.alink.operator.common.classification.NaiveBayesTextModelDataConverter;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.classification.NaiveBayesTextTrainParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
@NameCn("朴素贝叶斯文本分类训练")
@NameEn("Naive Bayes Text Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.NaiveBayesTextClassifier")
/* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp.class */
public class NaiveBayesTextTrainBatchOp extends BatchOperator<NaiveBayesTextTrainBatchOp> implements NaiveBayesTextTrainParams<NaiveBayesTextTrainBatchOp>, WithModelInfoBatchOp<NaiveBayesTextModelInfo, NaiveBayesTextTrainBatchOp, NaiveBayesTextModelInfoBatchOp> {
    private static final long serialVersionUID = 1343509041059789517L;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp$GenerateModel.class */
    public static class GenerateModel extends AbstractRichFunction implements MapPartitionFunction<Tuple3<Object, Double, Vector>, Row> {
        private static final long serialVersionUID = -8763129628694985528L;
        private int numFeature;
        private double smoothing;
        private NaiveBayesTextTrainParams.ModelType modelType;
        private String vectorColName;
        private String[] featureCols;
        private TypeInformation labelType;

        GenerateModel(double d, NaiveBayesTextTrainParams.ModelType modelType, String str, String[] strArr, TypeInformation typeInformation) {
            this.smoothing = d;
            this.modelType = modelType;
            this.labelType = typeInformation;
            this.vectorColName = str;
            this.featureCols = strArr;
        }

        /* JADX WARN: Failed to find 'out' block for switch in B:10:0x00b1. Please report as an issue. */
        public void mapPartition(Iterable<Tuple3<Object, Double, Vector>> iterable, Collector<Row> collector) throws Exception {
            double d = 0.0d;
            ArrayList arrayList = new ArrayList();
            for (Tuple3<Object, Double, Vector> tuple3 : iterable) {
                d += ((Double) tuple3.f1).doubleValue();
                arrayList.add(Tuple3.of(tuple3.f0, tuple3.f1, (DenseVector) tuple3.f2));
            }
            int size = arrayList.size();
            double log = Math.log(d + (size * this.smoothing));
            DenseMatrix denseMatrix = new DenseMatrix(size, this.numFeature);
            double[] dArr = new double[size];
            Object[] objArr = new Object[size];
            for (int i = 0; i < size; i++) {
                DenseVector denseVector = (DenseVector) ((Tuple3) arrayList.get(i)).f2;
                double d2 = 0.0d;
                switch (this.modelType) {
                    case Multinomial:
                        double d3 = 0.0d;
                        for (int i2 = 0; i2 < denseVector.size(); i2++) {
                            d3 += denseVector.get(i2);
                        }
                        d2 = Criteria.INVALID_GAIN + Math.log(d3 + (this.numFeature * this.smoothing));
                        break;
                    case Bernoulli:
                        d2 = Criteria.INVALID_GAIN + Math.log(((Double) ((Tuple3) arrayList.get(i)).f1).doubleValue() + (2.0d * this.smoothing));
                        break;
                }
                objArr[i] = ((Tuple3) arrayList.get(i)).f0;
                dArr[i] = Math.log(((Double) ((Tuple3) arrayList.get(i)).f1).doubleValue() + this.smoothing) - log;
                for (int i3 = 0; i3 < denseVector.size(); i3++) {
                    denseMatrix.set(i, i3, Math.log(denseVector.get(i3) + this.smoothing) - d2);
                }
            }
            NaiveBayesTextModelData naiveBayesTextModelData = new NaiveBayesTextModelData();
            naiveBayesTextModelData.pi = dArr;
            naiveBayesTextModelData.labels = objArr;
            naiveBayesTextModelData.theta = denseMatrix;
            naiveBayesTextModelData.vectorColName = this.vectorColName;
            naiveBayesTextModelData.modelType = this.modelType;
            naiveBayesTextModelData.featureColNames = this.featureCols;
            naiveBayesTextModelData.vectorSize = ((DenseVector) ((Tuple3) arrayList.get(0)).f2).size();
            new NaiveBayesTextModelDataConverter(this.labelType).save(naiveBayesTextModelData, collector);
        }

        public void open(Configuration configuration) throws Exception {
            this.numFeature = ((Integer) getRuntimeContext().getBroadcastVariable(KMeansTrainBatchOp.VECTOR_SIZE).get(0)).intValue();
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp$ReduceItem.class */
    public static class ReduceItem extends AbstractRichFunction implements GroupReduceFunction<Tuple3<Object, Double, Vector>, Tuple3<Object, Double, Vector>> {
        private static final long serialVersionUID = -3644173529201603819L;
        private int vectorSize = 0;

        public void reduce(Iterable<Tuple3<Object, Double, Vector>> iterable, Collector<Tuple3<Object, Double, Vector>> collector) {
            Object obj = null;
            double d = 0.0d;
            DenseVector denseVector = new DenseVector(this.vectorSize);
            for (Tuple3<Object, Double, Vector> tuple3 : iterable) {
                obj = tuple3.f0;
                double doubleValue = ((Double) tuple3.f1).doubleValue();
                d += doubleValue;
                if (tuple3.f2 instanceof SparseVector) {
                    ((SparseVector) tuple3.f2).setSize(this.vectorSize);
                    int[] indices = ((SparseVector) tuple3.f2).getIndices();
                    double[] values = ((SparseVector) tuple3.f2).getValues();
                    for (int i = 0; i < indices.length; i++) {
                        denseVector.add(indices[i], values[i] * doubleValue);
                    }
                } else {
                    for (int i2 = 0; i2 < this.vectorSize; i2++) {
                        denseVector.set(i2, denseVector.get(i2) + (((Vector) tuple3.f2).get(i2) * doubleValue));
                    }
                }
            }
            collector.collect(new Tuple3(obj, Double.valueOf(d), denseVector));
        }

        public void open(Configuration configuration) throws Exception {
            this.vectorSize = ((Integer) getRuntimeContext().getBroadcastVariable(KMeansTrainBatchOp.VECTOR_SIZE).get(0)).intValue();
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp$SelectLabel.class */
    public static class SelectLabel implements KeySelector<Tuple3<Object, Double, Vector>, String> {
        private static final long serialVersionUID = -3406893536674801451L;

        public String getKey(Tuple3<Object, Double, Vector> tuple3) {
            return tuple3.f0.toString();
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp$Transform.class */
    public static class Transform implements MapPartitionFunction<Tuple2<Vector, Row>, Tuple3<Object, Double, Vector>> {
        private static final long serialVersionUID = -1962725988758297056L;

        public void mapPartition(Iterable<Tuple2<Vector, Row>> iterable, Collector<Tuple3<Object, Double, Vector>> collector) throws Exception {
            for (Tuple2<Vector, Row> tuple2 : iterable) {
                collector.collect(new Tuple3(((Row) tuple2.f1).getArity() == 2 ? ((Row) tuple2.f1).getField(1) : ((Row) tuple2.f1).getField(0), Double.valueOf(((Row) tuple2.f1).getArity() == 2 ? ((Row) tuple2.f1).getField(0) instanceof Number ? ((Number) ((Row) tuple2.f1).getField(0)).doubleValue() : Double.parseDouble(((Row) tuple2.f1).getField(0).toString()) : 1.0d), (Vector) tuple2.f0));
            }
        }
    }

    /** Creates the operator with a fresh, empty {@link Params} instance. */
    public NaiveBayesTextTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters passed on to the {@link BatchOperator} constructor
     */
    public NaiveBayesTextTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the training data flow on the first input operator and sets the model as output.
     *
     * <p>Steps: read the vector column together with the label (and, if set, weight) column,
     * derive the vector size from the vector summary, turn every row into
     * {@code (label, weight, features)}, aggregate per label, and generate the model rows in a
     * single partition.
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public NaiveBayesTextTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        String labelColName = getLabelCol();
        NaiveBayesTextTrainParams.ModelType modelType = getModelType();
        String weightColName = getWeightCol();
        double smoothing = getSmoothing().doubleValue();
        String vectorColName = getVectorCol();
        TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelColName);

        // Kept columns are (weight, label) when a weight column is configured, otherwise (label).
        String[] keptColNames;
        if (weightColName == null) {
            keptColNames = new String[]{labelColName};
        } else {
            keptColNames = new String[]{weightColName, labelColName};
        }
        Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> dataAndSummary = StatisticsHelper.summaryHelper(in, null, vectorColName, keptColNames);
        DataSet<Tuple2<Vector, Row>> vectorsWithKeptCols = dataAndSummary.f0;

        // Broadcast to both the reducer and the model generator under the same name.
        DataSet<Integer> vectorSize = dataAndSummary.f1.map(new MapFunction<BaseVectorSummary, Integer>() { // from class: com.alibaba.alink.operator.batch.classification.NaiveBayesTextTrainBatchOp.1
            private static final long serialVersionUID = -4626037497952553113L;

            public Integer map(BaseVectorSummary baseVectorSummary) {
                return Integer.valueOf(baseVectorSummary.vectorSize());
            }
        });

        DataSet<Tuple3<Object, Double, Vector>> perLabelSums = vectorsWithKeptCols
            .mapPartition(new Transform())
            .groupBy(new SelectLabel())
            .reduceGroup(new ReduceItem())
            .withBroadcastSet(vectorSize, KMeansTrainBatchOp.VECTOR_SIZE);

        // Feature column names are optional and stay null when the parameter is absent.
        String[] featureColNames = getParams().contains(HasFeatureCols.FEATURE_COLS)
            ? (String[]) getParams().get(HasFeatureCols.FEATURE_COLS)
            : null;

        DataSet<Row> modelRows = perLabelSums
            .mapPartition(new GenerateModel(smoothing, modelType, vectorColName, featureColNames, labelType))
            .withBroadcastSet(vectorSize, KMeansTrainBatchOp.VECTOR_SIZE)
            .setParallelism(1);
        setOutput(modelRows, new NaiveBayesTextModelDataConverter(labelType).getModelSchema());
        return this;
    }

    /**
     * Creates a model-info operator from this operator's parameters and links it to this
     * operator's output.
     *
     * @return the linked {@link NaiveBayesTextModelInfoBatchOp}
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public NaiveBayesTextModelInfoBatchOp getModelInfoBatchOp() {
        NaiveBayesTextModelInfoBatchOp modelInfoOp = new NaiveBayesTextModelInfoBatchOp(getParams());
        return modelInfoOp.linkFrom(this);
    }

    /**
     * Array-typed overload of {@code linkFrom} that casts its argument to
     * {@code BatchOperator<?>[]} and delegates to {@code linkFrom}.
     *
     * <p>NOTE(review): the {@code bridge}/{@code synthetic} markers suggest this is a
     * compiler-generated bridge method that the decompiler wrote out as source. Its parameter
     * type erases to the same type as the varargs {@code linkFrom(BatchOperator<?>...)} of this
     * class, so it presumably has to be dropped from hand-maintained source - confirm before
     * relying on it.
     *
     * @param batchOperatorArr input operators, forwarded unchanged
     * @return the result of the delegated {@code linkFrom} call
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ NaiveBayesTextTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
