package com.alibaba.alink.operator.common.tree.predictors;

import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.tree.Node;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Model mapper that encodes a sample as the set of leaves it reaches in a tree ensemble.
 *
 * <p>When the model is loaded, every leaf of every tree is assigned a global id. Leaves are numbered
 * depth-first, tree after tree, starting at 0, so the ids are contiguous and the ids of tree
 * {@code t} are all greater than the ids of trees {@code 0..t-1}. The prediction for a row is a
 * {@link SparseVector} whose size is the total number of leaves, holding {@code 1.0} at the id of
 * the leaf reached in each tree.
 */
public class TreeModelEncoderModelMapper extends TreeModelMapper {
    private static final long serialVersionUID = 4543856042065853798L;

    /** Total number of leaves over all trees; used as the size of the output vector. */
    private int dim;

    /** One wrapped root per tree of the loaded model, mirroring the tree structure with leaf ids. */
    private NodeWithId[] roots;

    /** Per-thread reusable row that the selected input columns are copied into before prediction. */
    private transient ThreadLocal<Row> inputBufferThreadLocal;

    /**
     * Wraps a tree {@link Node}. A leaf wrapper carries the leaf's global id; an internal wrapper
     * carries wrappers for the node's children, in the same order as {@link Node#getNextNodes()}.
     */
    public static class NodeWithId {
        Node node;
        int id;
        NodeWithId[] next;

        private NodeWithId() {
        }

        public NodeWithId setNode(Node node) {
            this.node = node;
            return this;
        }

        public NodeWithId setId(int i) {
            this.id = i;
            return this;
        }

        public NodeWithId setNext(NodeWithId[] nodeWithIdArr) {
            this.next = nodeWithIdArr;
            return this;
        }
    }

    public TreeModelEncoderModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
    }

    /** The prediction column always holds a sparse vector of leaf indicators. */
    @Override
    protected TypeInformation<?> initPredResultColType(TableSchema tableSchema) {
        return AlinkTypes.SPARSE_VECTOR;
    }

    /**
     * Loads the tree model, assigns global leaf ids and prepares the per-thread input buffer.
     *
     * @param list serialized model rows, passed to {@code init}
     */
    @Override
    public void loadModel(List<Row> list) {
        init(list);
        this.roots = new NodeWithId[this.treeModel.roots.length];
        this.dim = encode();
        this.inputBufferThreadLocal = ThreadLocal.withInitial(
            () -> new Row(((String[]) this.ioSchema.f0).length));
    }

    /**
     * Recursively wraps the subtree under {@code nodeWithId} and numbers its leaves depth-first.
     *
     * @param nodeWithId wrapper whose {@code node} is already set
     * @param nextId     the id to give the first leaf encountered
     * @return the id following the last leaf assigned in this subtree
     */
    private int encodeDFS(NodeWithId nodeWithId, int nextId) {
        Node node = nodeWithId.node;
        if (node.isLeaf()) {
            nodeWithId.setId(nextId);
            return nextId + 1;
        }
        Node[] children = node.getNextNodes();
        nodeWithId.setNext(new NodeWithId[children.length]);
        for (int c = 0; c < children.length; c++) {
            nodeWithId.next[c] = new NodeWithId().setNode(children[c]);
            nextId = encodeDFS(nodeWithId.next[c], nextId);
        }
        return nextId;
    }

    /**
     * Wraps every tree root and assigns contiguous leaf ids across all trees.
     *
     * @return the total number of leaves
     */
    private int encode() {
        int nextId = 0;
        for (int t = 0; t < this.roots.length; t++) {
            this.roots[t] = new NodeWithId().setNode(this.treeModel.roots[t]);
            nextId = encodeDFS(this.roots[t], nextId);
        }
        return nextId;
    }

    /**
     * Picks the child to follow when the split value is missing or unusable.
     *
     * <p>If the node records exactly one missing-value child, that child is used. Otherwise the
     * result is the first child with the strictly largest positive counter weight sum, or child 0
     * when no child has a positive weight.
     */
    private int selectMaxWeightedCriteriaOfChild(Node node) {
        if (node.getMissingSplit() != null && node.getMissingSplit().length == 1) {
            return node.getMissingSplit()[0];
        }
        int best = 0;
        double bestWeight = 0.0d;
        int position = 0;
        for (Node child : node.getNextNodes()) {
            if (child.getCounter() != null) {
                double weightSum = child.getCounter().getWeightSum();
                if (weightSum > bestWeight) {
                    bestWeight = weightSum;
                    best = position;
                }
            }
            position++;
        }
        return best;
    }

    /**
     * Decides which child of internal node {@code node} the given row goes to.
     *
     * @throws IllegalArgumentException if the split feature has no matching input column
     */
    private int selectChild(Row row, Node node) {
        int featureIndex = node.getFeatureIndex();
        int columnIndex = this.featuresIndex[featureIndex];
        if (columnIndex < 0) {
            throw new IllegalArgumentException("Can not find train column index: " + featureIndex);
        }
        Object value = row.getField(columnIndex);
        if (value == null) {
            return selectMaxWeightedCriteriaOfChild(node);
        }
        int[] categoricalSplit = node.getCategoricalSplit();
        if (categoricalSplit == null) {
            // Continuous split: child 0 for values <= threshold, child 1 otherwise.
            return ((Number) value).doubleValue() <= node.getContinuousSplit() ? 0 : 1;
        }
        int category = ((Number) value).intValue();
        // A category outside the split table (e.g. unseen during training) is treated like a
        // negative split entry, i.e. routed the same way as a missing value instead of throwing.
        int child = (category >= 0 && category < categoricalSplit.length) ? categoricalSplit[category] : -1;
        return child < 0 ? selectMaxWeightedCriteriaOfChild(node) : child;
    }

    /**
     * Walks one tree from {@code root} down to a leaf for the given row.
     *
     * @return the global id of the reached leaf
     */
    private int predictWithId(Row row, NodeWithId root) {
        NodeWithId current = root;
        while (!current.node.isLeaf()) {
            current = current.next[selectChild(row, current.node)];
        }
        return current.id;
    }

    @Override
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictResultDetail(slicedSelectedSample).f0;
    }

    /**
     * Encodes the sample as leaf indicators.
     *
     * @return a tuple whose first field is the {@link SparseVector} of reached leaves (or
     *     {@code null} when the model has no trees) and whose detail field is always {@code null}
     */
    @Override
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        Row row = this.inputBufferThreadLocal.get();
        slicedSelectedSample.fillRow(row);
        transform(row);
        int numTrees = this.roots.length;
        if (numTrees == 0) {
            return new Tuple2<>(null, null);
        }
        int[] leafIds = new int[numTrees];
        double[] ones = new double[numTrees];
        Arrays.fill(ones, 1.0d);
        for (int t = 0; t < numTrees; t++) {
            // Leaf ids of later trees are always larger, so leafIds comes out in ascending order.
            leafIds[t] = predictWithId(row, this.roots[t]);
        }
        return new Tuple2<>(new SparseVector(this.dim, leafIds, ones), null);
    }
}
