package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.model.LabeledModelDataConverter;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import com.alibaba.alink.params.shared.tree.HasFeatureImportanceType;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/TreeModelDataConverter.class */
public class TreeModelDataConverter extends LabeledModelDataConverter<TreeModelDataConverter, TreeModelDataConverter> implements Serializable {
    /**
     * Required meta entry holding the row range occupied by the serialized string-indexer model.
     * {@code serializeModel} writes it as {@code Partition.of(0, n)} where {@code n} is the number of
     * string-indexer rows; {@code deserializeModel} reads rows {@code [f0, f1)} back.
     */
    public static final ParamInfo<Partition> STRING_INDEXER_MODEL_PARTITION = ParamInfoFactory.createParamInfo("stringIndexerModelPartition", Partition.class).setDescription("stringIndexerModelPartition").setRequired().build();
    /**
     * Required meta entry holding one {@link Partition} per tree root; each partition is the
     * {@code [f0, f1)} range of serialized rows that belong to that tree.
     */
    public static final ParamInfo<Partitions> TREE_PARTITIONS = ParamInfoFactory.createParamInfo("treePartition", Partitions.class).setDescription("treePartition").setRequired().build();
    /** Optional string parameter named "importanceFirstCol"; defaults to "feature". */
    public static final ParamInfo<String> IMPORTANCE_FIRST_COL = ParamInfoFactory.createParamInfo("importanceFirstCol", String.class).setHasDefaultValue("feature").build();
    /** Optional string parameter named "importanceSecondCol"; defaults to "importance". */
    public static final ParamInfo<String> IMPORTANCE_SECOND_COL = ParamInfoFactory.createParamInfo("importanceSecondCol", String.class).setHasDefaultValue("importance").build();
    private static final long serialVersionUID = 6997356679076377663L;
    /** Rows of the string-indexer model, or {@code null} when the model carries none. */
    public List<Row> stringIndexerModelSerialized;
    /** Root node of every tree in the model, in serialization order. */
    public Node[] roots;
    /** Model meta parameters; also receives the partition entries written during serialization. */
    public Params meta;
    /** Label values stored alongside the model; may be {@code null} before serialization. */
    public Object[] labels;

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/TreeModelDataConverter$FeatureImportanceReducer.class */
    public static class FeatureImportanceReducer extends RichGroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = -4934720404391098100L;

        public void reduce(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
            ArrayList arrayList = new ArrayList();
            Iterator<Row> it = iterable.iterator();
            while (it.hasNext()) {
                arrayList.add(it.next());
            }
            TreeModelDataConverter treeModelDataConverter = new TreeModelDataConverter();
            treeModelDataConverter.load(arrayList);
            HashMap hashMap = new HashMap();
            HasFeatureImportanceType.FeatureImportanceType featureImportanceType = (HasFeatureImportanceType.FeatureImportanceType) treeModelDataConverter.meta.get(GbdtTrainParams.FEATURE_IMPORTANCE_TYPE);
            ArrayDeque arrayDeque = new ArrayDeque();
            double d = 0.0d;
            for (Node node : treeModelDataConverter.roots) {
                arrayDeque.push(node);
                while (!arrayDeque.isEmpty()) {
                    Node node2 = (Node) arrayDeque.pop();
                    if (!node2.isLeaf()) {
                        switch (featureImportanceType) {
                            case GAIN:
                                hashMap.merge(Integer.valueOf(node2.getFeatureIndex()), Double.valueOf(node2.getGain()), (v0, v1) -> {
                                    return Double.sum(v0, v1);
                                });
                                d += node2.getGain();
                                break;
                            case COVER:
                                hashMap.merge(Integer.valueOf(node2.getFeatureIndex()), Double.valueOf(node2.getCounter().getNumInst()), (v0, v1) -> {
                                    return Double.sum(v0, v1);
                                });
                                break;
                            case WEIGHT:
                                hashMap.merge(Integer.valueOf(node2.getFeatureIndex()), Double.valueOf(1.0d), (v0, v1) -> {
                                    return Double.sum(v0, v1);
                                });
                                break;
                            default:
                                throw new IllegalArgumentException("Could not find the feature importace type. type: " + featureImportanceType);
                        }
                        for (Node node3 : node2.getNextNodes()) {
                            arrayDeque.push(node3);
                        }
                    }
                }
            }
            if (d > Criteria.INVALID_GAIN) {
                double d2 = d;
                Iterator it2 = hashMap.keySet().iterator();
                while (it2.hasNext()) {
                    hashMap.compute((Integer) it2.next(), (num, d3) -> {
                        return Double.valueOf(d3.doubleValue() / d2);
                    });
                }
            }
            if (treeModelDataConverter.meta == null || !treeModelDataConverter.meta.contains(GbdtTrainParams.FEATURE_COLS)) {
                for (Map.Entry entry : hashMap.entrySet()) {
                    Row row = new Row(2);
                    row.setField(0, String.valueOf(entry.getKey()));
                    row.setField(1, entry.getValue());
                    collector.collect(row);
                }
                return;
            }
            String[] strArr = (String[]) treeModelDataConverter.meta.get(GbdtTrainParams.FEATURE_COLS);
            for (int i = 0; i < strArr.length; i++) {
                Row row2 = new Row(2);
                row2.setField(0, strArr[i]);
                row2.setField(1, hashMap.getOrDefault(Integer.valueOf(i), Double.valueOf(Criteria.INVALID_GAIN)));
                collector.collect(row2);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/TreeModelDataConverter$NodeSerializable.class */
    /**
     * Flat, serializable form of a single tree node: the node itself, the id assigned to it during
     * serialization, and the ids of its children ({@code null} until set).
     */
    public static class NodeSerializable implements Serializable {
        private static final long serialVersionUID = 3621266425332831183L;

        /** The wrapped tree node. */
        public Node node;
        /** Id assigned to {@link #node} within its tree. */
        public int id;
        /** Ids of the child nodes; left unset for nodes that have none recorded. */
        public int[] nextIds;

        private NodeSerializable() {
        }

        /**
         * Stores the wrapped node.
         *
         * @return this instance, for chaining
         */
        public NodeSerializable setNode(Node wrapped) {
            node = wrapped;
            return this;
        }

        /**
         * Stores the node id.
         *
         * @return this instance, for chaining
         */
        public NodeSerializable setId(int nodeId) {
            id = nodeId;
            return this;
        }

        /**
         * Stores the child ids.
         *
         * @return this instance, for chaining
         */
        public NodeSerializable setNextIds(int[] childIds) {
            nextIds = childIds;
            return this;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/TreeModelDataConverter$Partition.class */
    /**
     * Mutable pair of integer bounds {@code (f0, f1)}. Within this file it is used as a
     * start-inclusive, end-exclusive range of serialized model rows.
     */
    public static final class Partition implements Serializable {
        private static final long serialVersionUID = -646027497988196810L;
        private int f0;
        private int f1;

        /**
         * Creates a partition with the given bounds.
         *
         * @param i  value for {@code f0}
         * @param i2 value for {@code f1}
         * @return a new partition
         */
        public static Partition of(int i, int i2) {
            Partition created = new Partition();
            created.f0 = i;
            created.f1 = i2;
            return created;
        }

        /** Returns the first bound. */
        public int getF0() {
            return f0;
        }

        /**
         * Replaces the first bound.
         *
         * @return this instance, for chaining
         */
        public Partition setF0(int first) {
            f0 = first;
            return this;
        }

        /** Returns the second bound. */
        public int getF1() {
            return f1;
        }

        /**
         * Replaces the second bound.
         *
         * @return this instance, for chaining
         */
        public Partition setF1(int second) {
            f1 = second;
            return this;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/TreeModelDataConverter$Partitions.class */
    /** Serializable, ordered collection of {@link Partition} entries. */
    public static final class Partitions implements Serializable {
        private static final long serialVersionUID = -6525488145026282786L;
        private List<Partition> partitions = new ArrayList<>();

        /** Returns the backing list itself (not a copy). */
        public List<Partition> getPartitions() {
            return partitions;
        }

        /**
         * Replaces the backing list with {@code list}.
         *
         * @return this instance, for chaining
         */
        public Partitions setPartitions(List<Partition> list) {
            partitions = list;
            return this;
        }

        /**
         * Appends {@code partition} to the backing list.
         *
         * @return this instance, for chaining
         */
        public Partitions add(Partition partition) {
            partitions.add(partition);
            return this;
        }
    }

    /** Creates a converter without label type information (passes {@code null} to the typed constructor). */
    public TreeModelDataConverter() {
        this((TypeInformation) null);
    }

    /**
     * Creates a converter and forwards the given type information to the superclass constructor.
     *
     * @param typeInformation type information handed to {@code LabeledModelDataConverter}; may be {@code null}
     */
    public TreeModelDataConverter(TypeInformation typeInformation) {
        super(typeInformation);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Flattens the given model into meta parameters, a list of JSON strings and the label values.
     *
     * <p>The JSON list starts with the string-indexer rows (each row's fields as a JSON array),
     * followed by the serialized nodes of every tree. The row ranges are recorded in the model's
     * meta: {@link #STRING_INDEXER_MODEL_PARTITION} covers the string-indexer rows and
     * {@link #TREE_PARTITIONS} holds one {@code [start, end)} range per tree root.
     *
     * <p>All data is read from {@code treeModelDataConverter}, including the string-indexer rows, so
     * the result does not depend on the state of the instance the method is invoked on.
     *
     * @param treeModelDataConverter the model to serialize; its {@code meta} is updated in place
     * @return the meta, the serialized rows, and the labels ({@code null} when the model has none)
     */
    @Override // com.alibaba.alink.common.model.LabeledModelDataConverter
    public Tuple3<Params, Iterable<String>, Iterable<Object>> serializeModel(TreeModelDataConverter treeModelDataConverter) {
        List<String> serialized = new ArrayList<>();
        int offset = 0;

        List<Row> indexerRows = treeModelDataConverter.stringIndexerModelSerialized;
        if (indexerRows != null) {
            for (Row row : indexerRows) {
                Object[] fields = new Object[row.getArity()];
                for (int f = 0; f < fields.length; f++) {
                    fields[f] = row.getField(f);
                }
                serialized.add(JsonConverter.toJson(fields));
            }
            offset = serialized.size();
        }
        treeModelDataConverter.meta.set(STRING_INDEXER_MODEL_PARTITION, Partition.of(0, offset));

        // Each tree occupies the contiguous range [offset, offset + nodeCount).
        Partitions treePartitions = new Partitions();
        for (Node root : treeModelDataConverter.roots) {
            List<String> treeRows = serializeTree(root);
            int end = offset + treeRows.size();
            treePartitions.add(Partition.of(offset, end));
            serialized.addAll(treeRows);
            offset = end;
        }
        treeModelDataConverter.meta.set(TREE_PARTITIONS, treePartitions);

        Iterable<String> data = serialized;
        Iterable<Object> labelData = treeModelDataConverter.labels == null ? null : Arrays.asList(treeModelDataConverter.labels);
        return Tuple3.of(treeModelDataConverter.meta, data, labelData);
    }

    /**
     * Builds a converter from the given model parts and saves it into rows.
     *
     * <p>The converter's label type is resolved from the {@code LABEL_TYPE_NAME} entry of
     * {@code params}.
     *
     * @param list    tree roots
     * @param params  model meta
     * @param list2   string-indexer model rows (stored as given)
     * @param objArr  label values (stored as given)
     * @return the rows collected while saving the model
     */
    public static List<Row> saveModelWithData(List<Node> list, Params params, List<Row> list2, Object[] objArr) {
        String labelTypeName = (String) params.get(ModelParamName.LABEL_TYPE_NAME);
        TreeModelDataConverter converter = new TreeModelDataConverter(FlinkTypeConverter.getFlinkType(labelTypeName));

        converter.meta = params;
        converter.roots = list.toArray(new Node[0]);
        converter.stringIndexerModelSerialized = list2;
        converter.labels = objArr;

        RowCollector sink = new RowCollector();
        converter.save(converter, sink);
        return sink.getRows();
    }

    /**
     * Serializes a tree into one JSON string per node, in breadth-first order.
     *
     * <p>The root gets id 0; children receive consecutive ids in the order they are discovered, and
     * every non-leaf node records its children's ids in {@code nextIds}.
     *
     * @param node root of the tree
     * @return the JSON form of every node, root first
     */
    public static List<String> serializeTree(Node node) {
        List<String> serializedNodes = new ArrayList<>();
        ArrayDeque<NodeSerializable> queue = new ArrayDeque<>();
        queue.addLast(new NodeSerializable().setNode(node).setId(0));
        int nextId = 1;

        NodeSerializable current;
        while ((current = queue.pollFirst()) != null) {
            if (!current.node.isLeaf()) {
                Node[] children = current.node.getNextNodes();
                int[] childIds = new int[children.length];
                for (int c = 0; c < children.length; c++) {
                    childIds[c] = nextId;
                    queue.addLast(new NodeSerializable().setNode(children[c]).setId(nextId));
                    nextId++;
                }
                current.setNextIds(childIds);
            }
            serializedNodes.add(JsonConverter.toJson(current));
        }
        return serializedNodes;
    }

    /**
     * Rebuilds a tree from the JSON strings produced by {@link #serializeTree(Node)}.
     *
     * <p>Every entry must carry a distinct id in {@code [0, list.size())}, and every child id listed
     * in {@code nextIds} must refer to one of those entries. Any violation is reported as an
     * {@link AkIllegalModelException} rather than surfacing as an index or null-pointer error.
     *
     * @param list serialized nodes of one tree
     * @return the node with id 0, or {@code null} when {@code list} is empty
     * @throws AkIllegalModelException if an entry cannot be parsed, an id is out of range or missing,
     *                                 or a child id does not refer to a known node
     */
    public static Node deserializeTree(List<String> list) {
        int size = list.size();
        if (size == 0) {
            return null;
        }

        // Index the entries by id.
        NodeSerializable[] byId = new NodeSerializable[size];
        for (String json : list) {
            NodeSerializable entry = (NodeSerializable) JsonConverter.fromJson(json, NodeSerializable.class);
            if (entry == null || entry.node == null) {
                throw new AkIllegalModelException("Model is broken. Could not parse node: " + json);
            }
            if (entry.id < 0 || entry.id >= size) {
                throw new AkIllegalModelException("Model is broken. node index: " + entry.id);
            }
            byId[entry.id] = entry;
        }

        // A duplicated id leaves some other slot empty, so this also detects duplicates.
        for (int id = 0; id < size; id++) {
            if (byId[id] == null) {
                throw new AkIllegalModelException("Model is broken. index: " + id);
            }
        }

        // Re-attach the children of every node that recorded child ids.
        for (int id = 0; id < size; id++) {
            int[] childIds = byId[id].nextIds;
            if (childIds == null) {
                continue;
            }
            Node[] children = new Node[childIds.length];
            for (int c = 0; c < childIds.length; c++) {
                int childId = childIds[c];
                if (childId < 0 || childId >= size) {
                    throw new AkIllegalModelException("Model is broken. node index: " + childId);
                }
                children[c] = byId[childId].node;
            }
            byId[id].node.setNextNodes(children);
        }
        return byId[0].node;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Restores this converter's state from serialized model parts.
     *
     * <p>The string-indexer rows are read from the range stored under
     * {@link #STRING_INDEXER_MODEL_PARTITION} (left {@code null} when that range is empty), and one
     * tree is rebuilt for every range stored under {@link #TREE_PARTITIONS}.
     *
     * <p>The first field of each string-indexer row is converted to {@code Long} from whatever
     * {@link Number} subtype the JSON parser produced, so values outside the {@code int} range do not
     * fail with a {@code ClassCastException}.
     *
     * @param params    model meta; stored as {@code meta}
     * @param iterable  serialized rows (string-indexer rows followed by tree nodes)
     * @param iterable2 label values; stored as {@code labels}
     * @return this converter
     */
    @Override // com.alibaba.alink.common.model.LabeledModelDataConverter
    protected TreeModelDataConverter deserializeModel(Params params, Iterable<String> iterable, Iterable<Object> iterable2) {
        Partition indexerPartition = params.get(STRING_INDEXER_MODEL_PARTITION);

        // Materialize the rows so the partitions can address them by position.
        List<String> serialized = new ArrayList<>();
        iterable.forEach(serialized::add);

        if (indexerPartition.getF1() != indexerPartition.getF0()) {
            this.stringIndexerModelSerialized = new ArrayList<>();
            for (int r = indexerPartition.getF0(); r < indexerPartition.getF1(); r++) {
                Object[] fields = (Object[]) JsonConverter.fromJson(serialized.get(r), Object[].class);
                Long first = Long.valueOf(((Number) fields[0]).longValue());
                this.stringIndexerModelSerialized.add(Row.of(first, fields[1], fields[2]));
            }
        } else {
            this.stringIndexerModelSerialized = null;
        }

        this.roots = params.get(TREE_PARTITIONS).getPartitions().stream()
            .map(treePartition -> deserializeTree(serialized.subList(treePartition.getF0(), treePartition.getF1())))
            .toArray(Node[]::new);
        this.meta = params;

        List<Object> labelValues = new ArrayList<>();
        iterable2.forEach(labelValues::add);
        this.labels = labelValues.toArray();
        return this;
    }

    // Bridge method as emitted by the decompiler (see the bridge/synthetic markers below): it accepts
    // the raw Iterable arguments and forwards them, cast to Iterable<String> / Iterable<Object>, to
    // deserializeModel.
    // NOTE(review): presumably javac regenerates this bridge from the typed override, in which case
    // this hand-written copy is redundant in source form — confirm before keeping it.
    @Override // com.alibaba.alink.common.model.LabeledModelDataConverter
    protected /* bridge */ /* synthetic */ TreeModelDataConverter deserializeModel(Params params, Iterable iterable, Iterable iterable2) {
        return deserializeModel(params, (Iterable<String>) iterable, (Iterable<Object>) iterable2);
    }
}
