package com.alibaba.alink.operator.batch.classification;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.classification.NaiveBayesModelData;
import com.alibaba.alink.operator.common.classification.NaiveBayesModelDataConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.params.classification.NaiveBayesTrainParams;
import com.alibaba.alink.params.shared.colname.HasCategoricalCols;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "categoricalCols", allowedTypeCollections = {TypeCollections.NAIVE_BAYES_CATEGORICAL_TYPES}), @ParamSelectColumnSpec(name = "featureCols"), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.DOUBLE_TYPE})})
@NameCn("朴素贝叶斯训练")
@NameEn("Naive Bayes Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.NaiveBayes")
/* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp.class */
public class NaiveBayesTrainBatchOp extends BatchOperator<NaiveBayesTrainBatchOp> implements NaiveBayesTrainParams<NaiveBayesTrainBatchOp>, WithModelInfoBatchOp<NaiveBayesModelInfo, NaiveBayesTrainBatchOp, NaiveBayesModelInfoBatchOp> {
    private static final long serialVersionUID = 8812570988530317418L;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp$GenerateModel.class */
    public static class GenerateModel extends AbstractRichFunction implements MapPartitionFunction<Tuple3<Object, Double[], HashMap<Integer, Double>[]>, NaiveBayesModelData> {
        private static final long serialVersionUID = -8733125133076037943L;
        private int featureSize;
        private double smoothing;
        private String[] featureColNames;
        private TypeInformation labelType;
        private boolean[] isCate;
        private List<Row> stringIndexerModel;

        GenerateModel(double d, String[] strArr, TypeInformation typeInformation, boolean[] zArr) {
            this.smoothing = d;
            this.labelType = typeInformation;
            this.featureColNames = strArr;
            this.isCate = zArr;
            this.featureSize = strArr.length;
        }

        public void open(Configuration configuration) throws Exception {
            this.stringIndexerModel = getRuntimeContext().getBroadcastVariable("stringIndexerModel");
        }

        public void mapPartition(Iterable<Tuple3<Object, Double[], HashMap<Integer, Double>[]>> iterable, Collector<NaiveBayesModelData> collector) throws Exception {
            double[] dArr = new double[this.featureSize];
            ArrayList arrayList = new ArrayList();
            HashSet[] hashSetArr = new HashSet[this.featureSize];
            for (int i = 0; i < this.featureSize; i++) {
                hashSetArr[i] = new HashSet();
            }
            for (Tuple3<Object, Double[], HashMap<Integer, Double>[]> tuple3 : iterable) {
                arrayList.add(tuple3);
                for (int i2 = 0; i2 < this.featureSize; i2++) {
                    int i3 = i2;
                    dArr[i3] = dArr[i3] + ((Double[]) tuple3.f1)[i2].doubleValue();
                    hashSetArr[i2].addAll(((HashMap[]) tuple3.f2)[i2].keySet());
                }
            }
            int[] iArr = new int[this.featureSize];
            double d = 0.0d;
            int size = arrayList.size();
            for (int i4 = 0; i4 < this.featureSize; i4++) {
                iArr[i4] = hashSetArr[i4].size();
                d += dArr[i4];
            }
            double log = Math.log(d + (size * this.smoothing));
            Number[][][] numberArr = new Number[size][this.featureSize];
            double[] dArr2 = new double[size];
            double[] dArr3 = new double[size];
            Object[] objArr = new Object[size];
            for (int i5 = 0; i5 < size; i5++) {
                HashMap[] hashMapArr = (HashMap[]) ((Tuple3) arrayList.get(i5)).f2;
                for (int i6 = 0; i6 < this.featureSize; i6++) {
                    int i7 = iArr[i6];
                    Number[] numberArr2 = new Number[i7];
                    if (this.isCate[i6]) {
                        double log2 = Math.log(((Double[]) ((Tuple3) arrayList.get(i5)).f1)[i6].doubleValue() + (this.smoothing * iArr[i6]));
                        for (int i8 = 0; i8 < i7; i8++) {
                            double d2 = 0.0d;
                            if (hashMapArr[i6].containsKey(Integer.valueOf(i8))) {
                                d2 = ((Double) hashMapArr[i6].get(Integer.valueOf(i8))).doubleValue();
                            }
                            numberArr2[i8] = Double.valueOf(Math.log(d2 + this.smoothing) - log2);
                        }
                    } else {
                        for (int i9 = 0; i9 < i7; i9++) {
                            numberArr2[i9] = (Number) hashMapArr[i6].get(Integer.valueOf(i9));
                        }
                    }
                    numberArr[i5][i6] = numberArr2;
                }
                objArr[i5] = ((Tuple3) arrayList.get(i5)).f0;
                double d3 = 0.0d;
                for (Double d4 : (Double[]) ((Tuple3) arrayList.get(i5)).f1) {
                    d3 += d4.doubleValue();
                }
                dArr3[i5] = d3;
                dArr2[i5] = Math.log(d3 + this.smoothing) - log;
            }
            NaiveBayesModelData naiveBayesModelData = new NaiveBayesModelData();
            naiveBayesModelData.featureNames = this.featureColNames;
            naiveBayesModelData.isCate = this.isCate;
            naiveBayesModelData.label = objArr;
            naiveBayesModelData.piArray = dArr2;
            naiveBayesModelData.labelWeights = dArr3;
            naiveBayesModelData.theta = numberArr;
            naiveBayesModelData.stringIndexerModelSerialized = this.stringIndexerModel;
            naiveBayesModelData.generateWeightAndNumbers(arrayList);
            collector.collect(naiveBayesModelData);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp$ReduceItem.class */
    public static class ReduceItem extends AbstractRichFunction implements GroupReduceFunction<Tuple3<Object, Double, Number[]>, Tuple3<Object, Double[], HashMap<Integer, Double>[]>> {
        private static final long serialVersionUID = 4271048412599201885L;
        boolean[] isCate;
        int featureSize;

        ReduceItem(boolean[] zArr) {
            this.isCate = zArr;
            this.featureSize = zArr.length;
        }

        public void reduce(Iterable<Tuple3<Object, Double, Number[]>> iterable, Collector<Tuple3<Object, Double[], HashMap<Integer, Double>[]>> collector) throws Exception {
            Double[] dArr = new Double[this.featureSize];
            Arrays.fill(dArr, Double.valueOf(Criteria.INVALID_GAIN));
            Object obj = null;
            HashMap[] hashMapArr = new HashMap[this.featureSize];
            for (int i = 0; i < this.featureSize; i++) {
                hashMapArr[i] = new HashMap(2);
            }
            for (Tuple3<Object, Double, Number[]> tuple3 : iterable) {
                obj = tuple3.f0;
                for (int i2 = 0; i2 < this.featureSize; i2++) {
                    if (((Number[]) tuple3.f2)[i2] != null) {
                        if (this.isCate[i2]) {
                            hashMapArr[i2].compute((Integer) ((Number[]) tuple3.f2)[i2], (num, d) -> {
                                return Double.valueOf(d == null ? ((Double) tuple3.f1).doubleValue() : d.doubleValue() + ((Double) tuple3.f1).doubleValue());
                            });
                        } else {
                            HashMap hashMap = hashMapArr[i2];
                            if (hashMap.containsKey(0)) {
                                hashMap.put(0, Double.valueOf(((Double) hashMap.get(0)).doubleValue() + (((Double) tuple3.f1).doubleValue() * ((Double) ((Number[]) tuple3.f2)[i2]).doubleValue())));
                                hashMap.put(1, Double.valueOf(((Double) hashMap.get(1)).doubleValue() + (((Double) tuple3.f1).doubleValue() * Math.pow(((Double) ((Number[]) tuple3.f2)[i2]).doubleValue(), 2.0d))));
                            } else {
                                hashMap.put(0, Double.valueOf(((Double) tuple3.f1).doubleValue() * ((Double) ((Number[]) tuple3.f2)[i2]).doubleValue()));
                                hashMap.put(1, Double.valueOf(((Double) tuple3.f1).doubleValue() * Math.pow(((Double) ((Number[]) tuple3.f2)[i2]).doubleValue(), 2.0d)));
                            }
                        }
                        int i3 = i2;
                        dArr[i3] = Double.valueOf(dArr[i3].doubleValue() + ((Double) tuple3.f1).doubleValue());
                    }
                }
            }
            for (int i4 = 0; i4 < this.featureSize; i4++) {
                if (!this.isCate[i4]) {
                    HashMap hashMap2 = hashMapArr[i4];
                    hashMap2.put(0, Double.valueOf(((Double) hashMap2.get(0)).doubleValue() / dArr[i4].doubleValue()));
                    hashMap2.put(1, Double.valueOf((((Double) hashMap2.get(1)).doubleValue() / dArr[i4].doubleValue()) - Math.pow(((Double) hashMap2.get(0)).doubleValue(), 2.0d)));
                }
            }
            collector.collect(Tuple3.of(obj, dArr, hashMapArr));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp$SelectLabel.class */
    public static class SelectLabel implements KeySelector<Tuple3<Object, Double, Number[]>, String> {
        private static final long serialVersionUID = -8656348545181937312L;

        private SelectLabel() {
        }

        public String getKey(Tuple3<Object, Double, Number[]> tuple3) {
            return tuple3.f0.toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp$Transform.class */
    public static class Transform implements MapPartitionFunction<Row, Tuple3<Object, Double, Number[]>> {
        private static final long serialVersionUID = 2035076744255855602L;
        private int labelColIndex;
        private int weightColIndex;
        private int[] featureColIndices;
        private int featureSize;

        Transform(int i, int i2, int[] iArr) {
            this.labelColIndex = i;
            this.weightColIndex = i2;
            this.featureColIndices = iArr;
            this.featureSize = iArr.length;
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Object, Double, Number[]>> collector) throws Exception {
            if (this.weightColIndex == -1) {
                for (Row row : iterable) {
                    Object field = row.getField(this.labelColIndex);
                    Number[] numberArr = new Number[this.featureSize];
                    for (int i = 0; i < this.featureSize; i++) {
                        numberArr[i] = (Number) row.getField(this.featureColIndices[i]);
                    }
                    collector.collect(Tuple3.of(field, Double.valueOf(1.0d), numberArr));
                }
                return;
            }
            for (Row row2 : iterable) {
                Object field2 = row2.getField(this.labelColIndex);
                Double d = (Double) row2.getField(this.weightColIndex);
                Number[] numberArr2 = new Number[this.featureSize];
                for (int i2 = 0; i2 < this.featureSize; i2++) {
                    numberArr2[i2] = (Number) row2.getField(this.featureColIndices[i2]);
                }
                collector.collect(Tuple3.of(field2, d, numberArr2));
            }
        }
    }

    /**
     * Creates the operator with a fresh, empty {@link Params} instance.
     */
    public NaiveBayesTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters handed to the {@link BatchOperator} constructor
     */
    public NaiveBayesTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Trains a naive Bayes model on the first input operator and sets the serialized model as this
     * operator's output.
     *
     * <p>Steps: resolve the label and feature columns, decide which features are categorical (this also
     * writes the resolved list back into the params), build the string indexer model, cast categorical and
     * continuous columns, then run {@link Transform} -> group by label -> {@link ReduceItem} ->
     * {@link GenerateModel} and serialize the result with {@link NaiveBayesModelDataConverter}.
     *
     * @param inputs input operators; only the first one is used
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public NaiveBayesTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> input = checkAndGetFirst(inputs);
        String labelCol = getLabelCol();
        final TypeInformation<?> labelType = input.getColTypes()[TableUtil.findColIndexWithAssertAndHint(input.getColNames(), labelCol)];

        // Feature types are taken from the raw input, i.e. before any casting happens.
        String[] featureCols = getFeatureCols();
        int featureCount = featureCols.length;
        int[] inputFeatureIndices = TableUtil.findColIndices(input.getColNames(), featureCols);
        TypeInformation<?>[] featureTypes = new TypeInformation<?>[featureCount];
        for (int i = 0; i < featureCount; i++) {
            featureTypes[i] = input.getColTypes()[inputFeatureIndices[i]];
        }

        String weightCol = getWeightCol() == null ? "" : getWeightCol();
        double smoothing = getSmoothing().doubleValue();

        // Must run before the Preprocessing calls: it stores the resolved categorical columns in the params,
        // which are then passed to Preprocessing.
        String[] declaredCategoricalCols = getParams().contains(HasCategoricalCols.CATEGORICAL_COLS) ? getCategoricalCols() : new String[0];
        boolean[] isCate = generateCategoricalCols(declaredCategoricalCols, featureCols, featureTypes, labelCol, getParams());

        BatchOperator<?> stringIndexerModel = Preprocessing.generateStringIndexerModel(input, getParams());
        BatchOperator<?> indexed = Preprocessing.castCategoricalCols(input, stringIndexerModel, getParams());
        BatchOperator<?> prepared = Preprocessing.castContinuousCols(indexed, getParams());

        // Resolve every column index against the table Transform actually reads, so label, weight and
        // feature positions cannot drift apart if preprocessing changes the column layout.
        String[] preparedCols = prepared.getColNames();
        int labelIndex = TableUtil.findColIndexWithAssertAndHint(preparedCols, labelCol);
        int weightIndex = TableUtil.findColIndex(preparedCols, weightCol);
        int[] featureIndices = TableUtil.findColIndices(preparedCols, featureCols);

        DataSet<Tuple3<Object, Double, Number[]>> samples = prepared.getDataSet()
            .mapPartition(new Transform(labelIndex, weightIndex, featureIndices));

        DataSet<Tuple3<Object, Double[], HashMap<Integer, Double>[]>> labelStats = samples
            .groupBy(new SelectLabel())
            .reduceGroup(new ReduceItem(isCate));

        // Parallelism 1 so that a single GenerateModel instance sees the statistics of every label.
        DataSet<NaiveBayesModelData> modelData = labelStats
            .mapPartition(new GenerateModel(smoothing, featureCols, labelType, isCate))
            .withBroadcastSet(stringIndexerModel.getDataSet(), "stringIndexerModel")
            .setParallelism(1);

        DataSet<Row> modelRows = modelData
            .flatMap(new FlatMapFunction<NaiveBayesModelData, Row>() {
                private static final long serialVersionUID = 4634517702171022715L;

                @Override
                public void flatMap(NaiveBayesModelData naiveBayesModelData, Collector<Row> collector) throws Exception {
                    new NaiveBayesModelDataConverter(labelType).save(naiveBayesModelData, collector);
                }
            })
            .setParallelism(1);

        setOutput(modelRows, new NaiveBayesModelDataConverter(labelType).getModelSchema());
        return this;
    }

    /**
     * Decides, per feature column, whether it is treated as categorical, and stores the resulting column
     * list in {@code params} under {@link HasCategoricalCols#CATEGORICAL_COLS}.
     *
     * <p>A feature is categorical when its type is inherently categorical (see
     * {@link #checkCategorical(TypeInformation)}), or when it has an integral type
     * ({@code BIG_INT}, {@code INT}, {@code LONG}) and was explicitly listed in {@code categoricalCols}.
     * A feature with the same name as the label column is skipped and stays non-categorical.
     *
     * @param categoricalCols columns the user declared as categorical (may be empty)
     * @param featureCols     feature column names
     * @param featureTypes    types of the feature columns, parallel to {@code featureCols}
     * @param labelCol        name of the label column
     * @param params          parameter set that receives the resolved categorical columns
     * @return flags parallel to {@code featureCols}; {@code true} marks a categorical feature
     * @throws AkIllegalOperatorParameterException if a declared categorical column has a non-categorical,
     *                                             non-integral type
     */
    private static boolean[] generateCategoricalCols(String[] categoricalCols, String[] featureCols, TypeInformation<?>[] featureTypes, String labelCol, Params params) {
        List<String> resolvedCategoricalCols = new ArrayList<>();
        int featureCount = featureCols.length;
        boolean[] isCate = new boolean[featureCount];
        for (int i = 0; i < featureCount; i++) {
            String featureCol = featureCols[i];
            if (featureCol.equals(labelCol)) {
                continue;
            }
            TypeInformation<?> featureType = featureTypes[i];
            if (!checkCategorical(featureType)) {
                boolean declaredCategorical = TableUtil.findColIndex(categoricalCols, featureCol) != -1;
                if (!declaredCategorical) {
                    // Plain continuous feature.
                    continue;
                }
                boolean integral = featureType.equals(Types.BIG_INT) || featureType.equals(Types.INT) || featureType.equals(Types.LONG);
                if (!integral) {
                    throw new AkIllegalOperatorParameterException("column \"" + featureCol + "\"'s type is " + featureType + ", which is not categorical!");
                }
            }
            resolvedCategoricalCols.add(featureCol);
            isCate[i] = true;
        }
        params.set(HasCategoricalCols.CATEGORICAL_COLS, resolvedCategoricalCols.toArray(new String[0]));
        return isCate;
    }

    /**
     * Classifies a column type as inherently categorical or numeric.
     *
     * @param typeInformation the column type to classify
     * @return {@code true} for {@code STRING} and {@code BOOLEAN}; {@code false} for {@code DOUBLE},
     *         {@code FLOAT}, {@code BIG_INT}, {@code LONG} and {@code INT}
     * @throws AkUnsupportedOperationException for any other type
     */
    private static boolean checkCategorical(TypeInformation typeInformation) {
        final boolean discrete = typeInformation.equals(Types.STRING) || typeInformation.equals(Types.BOOLEAN);
        if (discrete) {
            return true;
        }
        final boolean numeric = typeInformation.equals(Types.DOUBLE)
            || typeInformation.equals(Types.FLOAT)
            || typeInformation.equals(Types.BIG_INT)
            || typeInformation.equals(Types.LONG)
            || typeInformation.equals(Types.INT);
        if (!numeric) {
            throw new AkUnsupportedOperationException("don't support the type " + typeInformation);
        }
        return false;
    }

    /**
     * Creates a {@link NaiveBayesModelInfoBatchOp} from this operator's params and links it to this
     * operator.
     *
     * @return the linked model-info operator
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public NaiveBayesModelInfoBatchOp getModelInfoBatchOp() {
        NaiveBayesModelInfoBatchOp modelInfoOp = new NaiveBayesModelInfoBatchOp(getParams());
        return modelInfoOp.linkFrom(this);
    }

    /**
     * Bridge overload (marked bridge/synthetic by the decompiler) that forwards the array to
     * {@code linkFrom(BatchOperator<?>...)} after a cast.
     *
     * NOTE(review): this looks like a compiler-generated bridge recovered by decompilation; its parameter
     * type appears to have the same erasure as the varargs overload, so it likely should not exist in
     * hand-written source - confirm before keeping or removing it.
     *
     * @param batchOperatorArr input operators, passed through unchanged
     * @return the result of the delegated {@code linkFrom} call
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ NaiveBayesTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
