package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.operator.common.tree.TreeUtil;
import com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils;
import com.alibaba.alink.operator.common.tree.viz.TreeModelViz;
import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils;
import com.alibaba.alink.params.classification.RandomForestTrainParams;
import java.io.IOException;
import java.io.Serializable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.flink.types.Row;

/**
 * Inspectable summary of a trained tree model (single decision tree, random forest or GBDT).
 *
 * <p>Built from the serialized model rows, it exposes the model meta (number of trees, feature and
 * categorical columns, labels, feature importance), a textual {@link #toString()} summary, a
 * SQL-style {@code CASE WHEN} rendering of one tree, and export of one tree as an image.
 */
public abstract class TreeModelInfo implements Serializable {
    private static final long serialVersionUID = 316584854790096878L;

    /** Model-meta key holding the feature importance, serialized as a JSON map feature -> importance. */
    static final ParamInfo<String> FEATURE_IMPORTANCE = ParamInfoFactory.createParamInfo("featureImportance", String.class).build();

    /** Maximum number of rows shown in the feature-importance table of {@link #toString()}. */
    private static final int MAX_IMPORTANCE_ROWS = 50;

    /** Deserialized model data: meta, tree roots and labels. */
    TreeModelDataConverter dataConverter;
    /** String indexer for categorical features; {@code null} when the model carries none. */
    MultiStringIndexerModelData multiStringIndexerModelData;
    /** Feature name -> importance; {@code null} when the model meta has no importance entry. */
    Map<String, Double> featureImportance;
    /** {@code true} when leaves are rendered as numbers rather than class labels. */
    boolean isRegressionTree;

    /** Model info of a model that contains exactly one tree. */
    public static class DecisionTreeModelInfo extends TreeModelInfo {
        private static final long serialVersionUID = -3670502627904480174L;

        public DecisionTreeModelInfo(List<Row> list) {
            super(list);
        }

        /**
         * Renders the single tree as a nested SQL-style {@code CASE WHEN} expression.
         *
         * @return the rule text, or {@code null} when the model meta has no feature column names
         * @throws RuntimeException (from {@code AkPreconditions}) if the model does not hold exactly one tree
         */
        public String getCaseWhenRule() {
            checkSingleTree();
            return getCaseWhenRuleFromTreeId(0);
        }

        /**
         * Writes the single tree as an image file.
         *
         * @param path        destination passed to {@code TreeModelViz.toImageFile}
         * @param isOverwrite forwarded to {@code TreeModelViz.toImageFile} (presumably: overwrite an existing file)
         * @return this instance, for chaining
         * @throws IOException if writing the image fails
         */
        public TreeModelInfo saveTreeAsImage(String path, boolean isOverwrite) throws IOException {
            checkSingleTree();
            return saveTreeAsImageFromTreeId(path, 0, isOverwrite);
        }

        /** Fails unless the model holds exactly one tree root. */
        private void checkSingleTree() {
            int numRoots = this.dataConverter.roots.length;
            AkPreconditions.checkArgument(numRoots == 1, "This is not a decision tree model. length: %d", numRoots);
        }
    }

    /**
     * Model info of an ensemble, where individual trees are addressed by index.
     *
     * <p>This class is package-private; javac emits public visibility bridges for its public methods in
     * the public subclasses, so they remain callable (and reflectively visible) from outside the package.
     */
    static class MultiTreeModelInfo extends TreeModelInfo {
        private static final long serialVersionUID = -8437257628657305619L;

        public MultiTreeModelInfo(List<Row> list) {
            super(list);
        }

        /**
         * Renders tree {@code treeId} as a nested SQL-style {@code CASE WHEN} expression.
         *
         * @param treeId index of the tree, in {@code [0, number of roots)}
         * @return the rule text, or {@code null} when the model meta has no feature column names
         */
        public String getCaseWhenRule(int treeId) {
            return getCaseWhenRuleFromTreeId(treeId);
        }

        /**
         * Writes tree {@code treeId} as an image file.
         *
         * @param path        destination passed to {@code TreeModelViz.toImageFile}
         * @param treeId      index of the tree, in {@code [0, number of roots)}
         * @param isOverwrite forwarded to {@code TreeModelViz.toImageFile} (presumably: overwrite an existing file)
         * @return this instance, for chaining
         * @throws IOException if writing the image fails
         */
        public TreeModelInfo saveTreeAsImage(String path, int treeId, boolean isOverwrite) throws IOException {
            return saveTreeAsImageFromTreeId(path, treeId, isOverwrite);
        }
    }

    /** Model info of a gradient-boosted tree ensemble. */
    public static final class GbdtModelInfo extends MultiTreeModelInfo {
        private static final long serialVersionUID = -859598180490206967L;

        public GbdtModelInfo(List<Row> list) {
            super(list);
        }
    }

    /** Model info of a random forest. */
    public static final class RandomForestModelInfo extends MultiTreeModelInfo {
        private static final long serialVersionUID = -6423403615369604045L;

        public RandomForestModelInfo(List<Row> list) {
            super(list);
        }
    }

    /**
     * Loads the model rows and derives the categorical indexer, feature importance and tree kind.
     *
     * @param list serialized model rows, as accepted by {@code TreeModelDataConverter.load}
     */
    public TreeModelInfo(List<Row> list) {
        this.dataConverter = new TreeModelDataConverter().load(list);
        if (this.dataConverter.stringIndexerModelSerialized != null) {
            this.multiStringIndexerModelData =
                new MultiStringIndexerModelDataConverter().load(this.dataConverter.stringIndexerModelSerialized);
        }
        if (this.dataConverter.meta.contains(FEATURE_IMPORTANCE)) {
            this.featureImportance = parseFeatureImportance((String) this.dataConverter.meta.get(FEATURE_IMPORTANCE));
        }
        this.isRegressionTree = detectRegressionTree();
    }

    /** Parses the JSON feature-importance map stored in the model meta. */
    @SuppressWarnings("unchecked") // the TypeReference fixes the target type; the cast only restores the generics
    private static Map<String, Double> parseFeatureImportance(String json) {
        return (Map<String, Double>) JsonConverter.fromJson(json, new TypeReference<Map<String, Double>>() {
        }.getType());
    }

    /**
     * Decides whether leaves should be rendered as numeric values.
     *
     * <p>Models whose meta carries a loss type or a GBDT algorithm type are always treated as regression
     * trees; otherwise the stored tree type decides.
     */
    private boolean detectRegressionTree() {
        if (this.dataConverter.meta.contains(LossUtils.LOSS_TYPE)
            || this.dataConverter.meta.contains(BaseGbdtTrainBatchOp.ALGO_TYPE)) {
            return true;
        }
        return Criteria.isRegression((TreeUtil.TreeType) this.dataConverter.meta.get(TreeUtil.TREE_TYPE));
    }

    /** Fails unless {@code treeId} addresses one of the model's tree roots. */
    private void checkTreeId(int treeId) {
        int numRoots = this.dataConverter.roots.length;
        AkPreconditions.checkArgument(treeId >= 0 && treeId < numRoots,
            "treeId should be in range [0, %d), treeId: %d", numRoots, treeId);
    }

    /**
     * Renders tree {@code i} as a nested SQL-style {@code CASE WHEN} expression.
     *
     * @param i index of the tree, in {@code [0, number of roots)}
     * @return the rule text, or {@code null} when the model meta has no feature column names
     */
    protected String getCaseWhenRuleFromTreeId(int i) {
        checkTreeId(i);
        String[] features = getFeatures();
        if (features == null) {
            return null;
        }
        StringBuilder rule = new StringBuilder();
        appendNode(this.dataConverter.roots[i], features, rule);
        return rule.toString();
    }

    /**
     * Writes tree {@code i} as an image via {@code TreeModelViz.toImageFile}.
     *
     * @param str path forwarded to {@code TreeModelViz.toImageFile}
     * @param i   index of the tree, in {@code [0, number of roots)}
     * @param z   flag forwarded to {@code TreeModelViz.toImageFile} (presumably: overwrite)
     * @return this instance, for chaining
     * @throws IOException if writing the image fails
     */
    protected TreeModelInfo saveTreeAsImageFromTreeId(String str, int i, boolean z) throws IOException {
        checkTreeId(i);
        TreeModelViz.toImageFile(str, this.dataConverter, i, z);
        return this;
    }

    /** Returns the feature importance map, or {@code null} when the model does not carry it. */
    public Map<String, Double> getFeatureImportance() {
        return this.featureImportance;
    }

    /** Returns the number of trees recorded in the model meta. */
    public int getNumTrees() {
        return ((Integer) this.dataConverter.meta.get(RandomForestTrainParams.NUM_TREES)).intValue();
    }

    /** Returns the feature column names from the model meta, or {@code null} when absent. */
    public String[] getFeatures() {
        if (this.dataConverter.meta.contains(RandomForestTrainParams.FEATURE_COLS)) {
            return (String[]) this.dataConverter.meta.get(RandomForestTrainParams.FEATURE_COLS);
        }
        return null;
    }

    /** Returns the categorical column names from the model meta, or {@code null} when absent. */
    public String[] getCategoricalFeatures() {
        if (this.dataConverter.meta.contains(RandomForestTrainParams.CATEGORICAL_COLS)) {
            return (String[]) this.dataConverter.meta.get(RandomForestTrainParams.CATEGORICAL_COLS);
        }
        return null;
    }

    /**
     * Returns the indexed tokens of categorical column {@code str}.
     *
     * @return the tokens, or {@code null} when the model has no string indexer data
     */
    public List<String> getCategoricalValues(String str) {
        if (this.multiStringIndexerModelData != null) {
            return this.multiStringIndexerModelData.getTokens(str);
        }
        return null;
    }

    /** Returns the class labels stored in the model data (may be {@code null}). */
    public Object[] getLabels() {
        return this.dataConverter.labels;
    }

    /**
     * Builds a multi-line summary: tree kind, tree/feature counts, labels, categorical feature sizes and
     * the top feature importances (descending).
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.isRegressionTree ? "Regression trees modelInfo: \n" : "Classification trees modelInfo: \n");
        sb.append("Number of trees: ").append(getNumTrees()).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);

        String[] features = getFeatures();
        String[] categoricalFeatures = getCategoricalFeatures();
        int numCategorical = categoricalFeatures == null ? 0 : categoricalFeatures.length;
        if (features != null) {
            sb.append("Number of features: ").append(features.length).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            sb.append("Number of categorical features: ").append(numCategorical)
                .append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
        }

        Object[] labels = getLabels();
        if (labels != null) {
            sb.append("Labels: ");
            sb.append(PrettyDisplayUtils.displayList(Arrays.asList(labels)));
            sb.append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
        }

        if (numCategorical > 0) {
            sb.append("\nCategorical feature info:\n");
            Object[][] rows = new Object[numCategorical][];
            for (int k = 0; k < numCategorical; k++) {
                List<String> values = getCategoricalValues(categoricalFeatures[k]);
                rows[k] = new Object[] {categoricalFeatures[k], values == null ? 0 : values.size()};
            }
            sb.append(PrettyDisplayUtils.displayTable(rows, numCategorical, 2, null,
                new String[] {"feature", "number of categorical value"}, null));
        }

        Map<String, Double> importance = getFeatureImportance();
        if (importance != null && !importance.isEmpty()) {
            appendFeatureImportance(importance, sb);
        }
        return sb.toString();
    }

    /** Appends the top {@link #MAX_IMPORTANCE_ROWS} feature importances, highest first, as a table. */
    private static void appendFeatureImportance(Map<String, Double> importance, StringBuilder sb) {
        int top = Math.min(importance.size(), MAX_IMPORTANCE_ROWS);
        sb.append("\nTable of feature importance Top ").append(top).append(": ")
            .append(CsvInputFormat.DEFAULT_LINE_DELIMITER);

        Object[][] rows = new Object[importance.size()][];
        int k = 0;
        for (Map.Entry<String, Double> entry : importance.entrySet()) {
            rows[k++] = new Object[] {entry.getKey(), entry.getValue()};
        }
        // Arrays.sort on objects is stable: equal importances keep the map's iteration order.
        Arrays.sort(rows, (a, b) -> Double.compare((Double) b[1], (Double) a[1]));
        Object[][] topRows = Arrays.copyOf(rows, top);
        sb.append(PrettyDisplayUtils.displayTable(topRows, topRows.length, 2, null,
            new String[] {"feature", "importance"}, null, top, 0, 2, false));
    }

    /** Recursively appends the rule text of the subtree rooted at {@code node}. */
    private void appendNode(Node node, String[] featureNames, StringBuilder sb) {
        if (node.isLeaf()) {
            appendLeaf(node, sb);
        } else if (node.getCategoricalSplit() == null) {
            appendContinuousSplit(node, featureNames, sb);
        } else {
            appendCategoricalSplit(node, featureNames, sb);
        }
    }

    /**
     * Appends a leaf's prediction: the first distribution entry for regression trees, otherwise the label
     * with the largest strictly positive distribution entry (ties resolved to the lowest index).
     */
    private void appendLeaf(Node leaf, StringBuilder sb) {
        double[] distributions = leaf.getCounter().getDistributions();
        if (this.isRegressionTree) {
            sb.append(printEightDecimal(distributions[0]));
            return;
        }
        int best = -1;
        double bestValue = 0.0d;
        for (int k = 0; k < distributions.length; k++) {
            if (distributions[k] > bestValue) {
                bestValue = distributions[k];
                best = k;
            }
        }
        AkPreconditions.checkArgument(best >= 0, "Can not find the probability: %s",
            JsonConverter.toJson(distributions));
        sb.append(this.dataConverter.labels[best]);
    }

    /** Appends {@code CASE WHEN f <= t THEN <left> WHEN f > t THEN <right> END} for a numeric split. */
    private void appendContinuousSplit(Node node, String[] featureNames, StringBuilder sb) {
        String feature = featureNames[node.getFeatureIndex()];
        String threshold = printEightDecimal(node.getContinuousSplit());
        Node[] children = node.getNextNodes();

        sb.append("CASE WHEN ").append(feature).append(" <= ").append(threshold).append(" THEN ");
        appendNode(children[0], featureNames, sb);
        sb.append(" WHEN ").append(feature).append(" > ").append(threshold).append(" THEN ");
        appendNode(children[1], featureNames, sb);
        sb.append(" END");
    }

    /**
     * Appends a categorical split: one {@code WHEN f = a or f = b ... THEN <child>} arm per child, where
     * the categorical-split array maps each category index to the child it routes to.
     */
    private void appendCategoricalSplit(Node node, String[] featureNames, StringBuilder sb) {
        String feature = featureNames[node.getFeatureIndex()];
        int[] categoryToChild = node.getCategoricalSplit();
        Node[] children = node.getNextNodes();

        for (int child = 0; child < children.length; child++) {
            // The leading space in " CASE WHEN " differs from the continuous branch; kept so existing
            // rule strings stay byte-identical.
            sb.append(child == 0 ? " CASE WHEN " : " WHEN ");
            boolean firstCondition = true;
            for (int category = 0; category < categoryToChild.length; category++) {
                if (categoryToChild[category] != child) {
                    continue;
                }
                if (!firstCondition) {
                    sb.append(" or ");
                }
                sb.append(feature).append(" = ")
                    .append(this.multiStringIndexerModelData.getToken(feature, Long.valueOf(category)));
                firstCondition = false;
            }
            sb.append(" THEN ");
            appendNode(children[child], featureNames, sb);
        }
        sb.append(" END");
    }

    /**
     * Formats a number for the rule text, independent of the default locale.
     *
     * <ul>
     *   <li>NaN / infinities: {@link String#valueOf(double)} (previously these threw
     *       {@code NumberFormatException} from {@code new BigDecimal(double)});</li>
     *   <li>integral values: exact digits without a fractional part (previously truncated through
     *       {@code int}, which saturated beyond +-2^31);</li>
     *   <li>values unchanged by rounding to 8 decimals: {@link String#valueOf(double)};</li>
     *   <li>everything else: rounded to 8 decimals with a {@code '.'} separator.</li>
     * </ul>
     */
    private static String printEightDecimal(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d)) {
            return String.valueOf(d);
        }
        if (d == Math.floor(d)) {
            return new BigDecimal(d).toPlainString();
        }
        double rounded = new BigDecimal(d).setScale(8, RoundingMode.HALF_UP).doubleValue();
        if (rounded == d) {
            return String.valueOf(d);
        }
        // Locale.ROOT keeps '.' as the decimal separator so the rule stays valid SQL in any locale.
        return String.format(Locale.ROOT, "%.8f", d);
    }
}
