package com.alibaba.alink.operator.batch.classification;

import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.classification.NaiveBayesModelDataConverter;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/batch/classification/NaiveBayesModelInfo.class */
/**
 * Summary of a trained Naive Bayes model for inspection and pretty printing.
 *
 * <p>Exposes the feature names, which features are categorical, the label proportions, and, per
 * label, the token proportions of categorical features and the mean/std of continuous features.
 */
public class NaiveBayesModelInfo implements Serializable {
    private static final long serialVersionUID = -1471696058725316172L;
    private String[] featureNames;
    private int featureSize;
    /** isCategorical[f] tells whether feature f is categorical (true) or continuous (false). */
    private boolean[] isCategorical;
    /** Raw (unnormalized) weight of each label, parallel to {@link #labels}. */
    private double[] labelWeights;
    private Object[] labels;
    private int labelSize;
    public double[][] weightSum;
    /** featureInfo[label][feature]: per-token weights (categorical) or [mean, std, ...] (continuous). */
    public SparseVector[][] featureInfo;
    public transient List<Row> stringIndexerModelSerialized;
    /** Feature name -> tokens observed for it; filled by {@link #getCategoryFeatureInfo()}. */
    private HashMap<Object, HashSet<Object>> cateFeatureValue;

    public NaiveBayesModelInfo() {
    }

    public NaiveBayesModelInfo(String[] featureNames, boolean[] isCategorical, double[] labelWeights,
                               Object[] labels, double[][] weightSum, SparseVector[][] featureInfo,
                               List<Row> stringIndexerModelSerialized) {
        this.featureNames = featureNames;
        this.featureSize = featureNames.length;
        this.cateFeatureValue = new HashMap<>(this.featureSize);
        this.isCategorical = isCategorical;
        this.labelWeights = labelWeights;
        this.labels = labels;
        this.labelSize = labels.length;
        this.weightSum = weightSum;
        this.featureInfo = featureInfo;
        this.stringIndexerModelSerialized = stringIndexerModelSerialized;
    }

    /** Builds the summary from serialized model rows via {@link NaiveBayesModelDataConverter}. */
    public NaiveBayesModelInfo(List<Row> list) {
        NaiveBayesModelInfo loaded = new NaiveBayesModelDataConverter().loadModelInfo(list);
        this.featureNames = loaded.featureNames;
        this.featureSize = this.featureNames.length;
        this.cateFeatureValue = new HashMap<>(this.featureSize);
        this.isCategorical = loaded.isCategorical;
        this.labelWeights = loaded.labelWeights;
        this.labels = loaded.labels;
        this.labelSize = this.labels.length;
        this.weightSum = loaded.weightSum;
        this.featureInfo = loaded.featureInfo;
        this.stringIndexerModelSerialized = loaded.stringIndexerModelSerialized;
    }

    public String[] getFeatureNames() {
        return this.featureNames;
    }

    /**
     * Computes, for every categorical feature, the share of each token's total weight that falls
     * under each label.
     *
     * <p>Categorical features are matched positionally to the string indexer's selected columns
     * (the k-th categorical feature uses the k-th indexed column). As a side effect, the set of
     * tokens seen for each categorical feature is recorded in {@code cateFeatureValue}.
     *
     * @return feature name -> label -> token -> proportion; empty if the string indexer model has
     *     no selected columns.
     */
    public HashMap<Object, HashMap<Object, HashMap<Object, Double>>> getCategoryFeatureInfo() {
        MultiStringIndexerModelData indexerData =
            new MultiStringIndexerModelDataConverter().load(this.stringIndexerModelSerialized);
        if (indexerData.meta == null || !indexerData.meta.contains(HasSelectedCols.SELECTED_COLS)) {
            return new HashMap<>(0);
        }
        String[] indexedCols = (String[]) indexerData.meta.get(HasSelectedCols.SELECTED_COLS);

        // One "token index -> token" map per indexed column; tuple is (column id, token, index).
        List<HashMap<Long, String>> tokensByCol = new ArrayList<>(indexedCols.length);
        for (String col : indexedCols) {
            tokensByCol.add(new HashMap<>((int) indexerData.getNumberOfTokensOfColumn(col)));
        }
        for (Tuple3<Integer, String, Long> entry : indexerData.tokenAndIndex) {
            tokensByCol.get(entry.f0).put(entry.f2, entry.f1);
        }

        // The no-arg constructor leaves this unset.
        if (this.cateFeatureValue == null) {
            this.cateFeatureValue = new HashMap<>(this.featureSize);
        }

        HashMap<Object, HashMap<Object, HashMap<Object, Double>>> result = new HashMap<>(this.featureSize);
        int indexedColId = 0;
        for (int f = 0; f < this.featureSize; f++) {
            if (!this.isCategorical[f]) {
                continue;
            }
            String featureName = this.featureNames[f];
            HashMap<Long, String> tokenOfIndex = tokensByCol.get(indexedColId);

            // Total weight of each token summed over all labels (the proportion denominator).
            double[] tokenTotals = new double[Math.toIntExact(
                indexerData.getNumberOfTokensOfColumn(indexedCols[indexedColId]))];
            for (int l = 0; l < this.labelSize; l++) {
                SparseVector weights = this.featureInfo[l][f];
                int[] idx = weights.getIndices();
                double[] vals = weights.getValues();
                for (int k = 0; k < idx.length; k++) {
                    tokenTotals[idx[k]] += vals[k];
                }
            }

            HashSet<Object> seenTokens = new HashSet<>();
            HashMap<Object, HashMap<Object, Double>> byLabel = new HashMap<>(this.labelSize);
            for (int l = 0; l < this.labelSize; l++) {
                SparseVector weights = this.featureInfo[l][f];
                int[] idx = weights.getIndices();
                double[] vals = weights.getValues();
                HashMap<Object, Double> proportions = new HashMap<>();
                for (int k = 0; k < idx.length; k++) {
                    Object token = tokenOfIndex.get((long) idx[k]);
                    seenTokens.add(token);
                    proportions.put(token, vals[k] / tokenTotals[idx[k]]);
                }
                byLabel.put(this.labels[l], proportions);
            }
            result.put(featureName, byLabel);
            this.cateFeatureValue.put(featureName, seenTokens);
            indexedColId++;
        }
        return result;
    }

    /**
     * Collects the stored statistics vector of every continuous feature, per label.
     *
     * @return label -> array indexed by feature position; entries for categorical features are
     *     null. Labels appear only if at least one continuous feature exists.
     */
    public HashMap<Object, double[][]> getGaussFeatureInfo() {
        HashMap<Object, double[][]> result = new HashMap<>(this.labelSize);
        for (int f = 0; f < this.featureSize; f++) {
            if (this.isCategorical[f]) {
                continue;
            }
            for (int l = 0; l < this.labelSize; l++) {
                // Jagged array: one row per feature, filled only for continuous features.
                double[][] stats = result.get(this.labels[l]);
                if (stats == null) {
                    stats = new double[this.featureSize][];
                    result.put(this.labels[l], stats);
                }
                stats[f] = this.featureInfo[l][f].getValues();
            }
        }
        return result;
    }

    public Object[] getLabelList() {
        return this.labels;
    }

    /**
     * Returns each label's weight divided by the total label weight.
     *
     * <p>The model's stored label weights are left untouched.
     */
    public Map<Comparable, Double> getLabelProportion() {
        double[] proportions = this.labelWeights.clone();
        normalizeArray(proportions);
        Map<Comparable, Double> result = new HashMap<>(this.labels.length);
        for (int i = 0; i < this.labels.length; i++) {
            result.put((Comparable) this.labels[i], proportions[i]);
        }
        return result;
    }

    /** Returns feature name -> whether the feature is categorical. */
    public Map<String, Boolean> getCategoryInfo() {
        Map<String, Boolean> result = new HashMap<>(this.featureNames.length);
        for (int i = 0; i < this.featureNames.length; i++) {
            result.put(this.featureNames[i], this.isCategorical[i]);
        }
        return result;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(PrettyDisplayUtils.displayHeadline("NaiveBayesModelInfo", '-') + CsvInputFormat.DEFAULT_LINE_DELIMITER);

        Map<String, String> meta = new HashMap<>();
        meta.put("feature col names", JsonConverter.toJson(this.featureNames));
        meta.put("feature size", String.valueOf(this.featureSize));
        meta.put("labels", JsonConverter.toJson(this.labels));
        meta.put("label number", String.valueOf(this.labelSize));
        sb.append(PrettyDisplayUtils.displayHeadline("model meta info", '='));
        sb.append(PrettyDisplayUtils.displayMap(meta, 10, false) + CsvInputFormat.DEFAULT_LINE_DELIMITER);

        // keySet() and values() of the same map iterate in matching order.
        Map<Comparable, Double> labelProportion = getLabelProportion();
        List<Double> proportions = new ArrayList<>(labelProportion.values());
        List<Comparable> labelKeys = new ArrayList<>(labelProportion.keySet());
        sb.append(PrettyDisplayUtils.displayHeadline("label proportion information", '=') + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        sb.append("label info:");
        sb.append(PrettyDisplayUtils.displayList(labelKeys, false) + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        sb.append("proportion:");
        sb.append(PrettyDisplayUtils.displayList(proportions, false) + CsvInputFormat.DEFAULT_LINE_DELIMITER);

        List<String> categorical = new ArrayList<>();
        List<String> continuous = new ArrayList<>();
        for (Map.Entry<String, Boolean> entry : getCategoryInfo().entrySet()) {
            if (entry.getValue()) {
                categorical.add(entry.getKey());
            } else {
                continuous.add(entry.getKey());
            }
        }
        sb.append(PrettyDisplayUtils.displayHeadline("category information", '=') + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        sb.append("categorical features: ");
        sb.append(PrettyDisplayUtils.displayList(categorical, false) + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        sb.append("gaussian features: ");
        sb.append(PrettyDisplayUtils.displayList(continuous, false) + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        sb.append(PrettyDisplayUtils.displayHeadline("categorical features proportion information", '=') + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        printCateFeatureInfo(sb);
        printMeanSigma(sb);
        return sb.toString();
    }

    /** Divides every element of {@code dArr} in place by the array's sum. */
    private static void normalizeArray(double[] dArr) {
        double total = 0.0d;
        for (double v : dArr) {
            total += v;
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] /= total;
        }
    }

    /** Returns the string form of every label, in label order (used as table row names). */
    private String[] labelNames() {
        String[] names = new String[this.labelSize];
        for (int i = 0; i < this.labelSize; i++) {
            names[i] = this.labels[i].toString();
        }
        return names;
    }

    /** Appends one label-by-token proportion table per categorical feature. */
    private void printCateFeatureInfo(StringBuilder sb) {
        HashMap<Object, HashMap<Object, HashMap<Object, Double>>> categoryFeatureInfo = getCategoryFeatureInfo();
        if (categoryFeatureInfo.isEmpty()) {
            sb.append("There is no category feature.\n");
            return;
        }
        String[] rowNames = labelNames();
        for (Map.Entry<Object, HashMap<Object, HashMap<Object, Double>>> entry : categoryFeatureInfo.entrySet()) {
            Object featureName = entry.getKey();
            Object[] tokens = this.cateFeatureValue.get(featureName).toArray();
            int numTokens = tokens.length;
            String[] colNames = new String[numTokens];
            for (int t = 0; t < numTokens; t++) {
                // A token index missing from the indexer maps to null; print it instead of crashing.
                colNames[t] = String.valueOf(tokens[t]);
            }
            Double[][] table = new Double[this.labelSize][numTokens];
            for (int l = 0; l < this.labelSize; l++) {
                // Tokens never seen with this label keep the fill value.
                Arrays.fill(table[l], Double.valueOf(Criteria.INVALID_GAIN));
                HashMap<Object, Double> proportions = entry.getValue().get(this.labels[l]);
                for (int t = 0; t < numTokens; t++) {
                    if (proportions.containsKey(tokens[t])) {
                        table[l][t] = proportions.get(tokens[t]);
                    }
                }
            }
            sb.append("The features proportion information of " + featureName.toString() + ":\n");
            sb.append(PrettyDisplayUtils.displayTable(table, this.labelSize, numTokens, rowNames, colNames, null, 3, 3) + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        }
    }

    /** Appends label-by-feature tables of the mean and std of each continuous feature. */
    private void printMeanSigma(StringBuilder sb) {
        sb.append(PrettyDisplayUtils.displayHeadline("continuous features mean sigma information", '=') + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        HashMap<Object, double[][]> gaussFeatureInfo = getGaussFeatureInfo();
        if (gaussFeatureInfo.isEmpty()) {
            sb.append("There is no continuous feature.\n");
            return;
        }
        // Positions of the continuous features; the table columns follow this order.
        List<Integer> continuousIdx = new ArrayList<>();
        for (int f = 0; f < this.featureSize; f++) {
            if (!this.isCategorical[f]) {
                continuousIdx.add(f);
            }
        }
        int numContinuous = continuousIdx.size();
        // Column headers must name only the continuous features, not the first N features.
        String[] colNames = new String[numContinuous];
        for (int j = 0; j < numContinuous; j++) {
            colNames[j] = this.featureNames[continuousIdx.get(j)];
        }
        String[] rowNames = labelNames();

        Double[][] means = new Double[this.labelSize][numContinuous];
        Double[][] stds = new Double[this.labelSize][numContinuous];
        for (int l = 0; l < this.labelSize; l++) {
            double[][] stats = gaussFeatureInfo.get(this.labels[l]);
            for (int j = 0; j < numContinuous; j++) {
                // stats vector layout as printed below: [0] = mean, [1] = std.
                double[] meanStd = stats[continuousIdx.get(j)];
                means[l][j] = meanStd[0];
                stds[l][j] = meanStd[1];
            }
        }
        sb.append("Mean of features of each label:\n");
        sb.append(PrettyDisplayUtils.displayTable(means, this.labelSize, numContinuous, rowNames, colNames, null, 3, 3) + CsvInputFormat.DEFAULT_LINE_DELIMITER);
        sb.append("Std of features of each label:\n");
        sb.append(PrettyDisplayUtils.displayTable(stds, this.labelSize, numContinuous, rowNames, colNames, null, 3, 3) + CsvInputFormat.DEFAULT_LINE_DELIMITER);
    }
}
