package com.alibaba.alink.operator.common.tree.predictors;

import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.LabelCounter;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp;
import com.alibaba.alink.operator.common.tree.parallelcart.SaveModel;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import com.alibaba.alink.params.shared.colname.HasVectorColDefaultAsNull;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/predictors/GbdtModelMapper.class */
/**
 * Model mapper that scores rows with a trained GBDT model.
 *
 * <p>Features are read either from a single vector column (when {@code vectorCol} is configured in
 * the params) or from the selected table columns handled by the parent {@link TreeModelMapper}.
 * Every tree root contributes its reached leaf counter to one accumulator. The raw score is the
 * accumulator's first distribution entry plus the model's stored {@code GBDT_Y_PERIOD} offset.
 * For classification models, that score is passed through a sigmoid. The result is
 * {@code labels[1]} when the sigmoid is at least 0.5, otherwise {@code labels[0]}.
 */
public class GbdtModelMapper extends TreeModelMapper {
    private static final long serialVersionUID = 75264909895533116L;

    /** Offset added to the accumulated tree output; read from the model meta in {@link #loadModel}. */
    private double period;
    /** Whether the raw score is interpreted as a classification logit (true) or a regression value. */
    private boolean isClassification;
    /** Index of the feature-vector column in the data schema, or -1 when table columns are used. */
    private int vectorColIndex;
    /**
     * Per-thread reusable row that receives the selected columns on the table path. It is only
     * created in {@link #loadModel} when no vector column is configured, and it is not serialized.
     */
    private transient ThreadLocal<Row> inputBufferThreadLocal;

    /**
     * @param modelSchema schema of the model rows
     * @param dataSchema  schema of the input data; the vector column (if any) is resolved here
     * @param params      mapper parameters; {@code vectorCol} switches to the vector input path
     */
    public GbdtModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.vectorColIndex = -1;
        // A vector column that is present but explicitly null is treated as "not configured".
        String vectorColName = params.contains(GbdtTrainParams.VECTOR_COL)
            ? (String) params.get(HasVectorColDefaultAsNull.VECTOR_COL)
            : null;
        if (vectorColName != null) {
            this.vectorColIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema, vectorColName);
        }
    }

    /**
     * Loads the trees and reads the scoring metadata (offset and task type) from the model meta.
     *
     * @param list serialized model rows
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        init(list);
        this.period = ((Double) this.treeModel.meta.get(SaveModel.GBDT_Y_PERIOD)).doubleValue();
        if (this.treeModel.meta.contains(LossUtils.LOSS_TYPE)) {
            this.isClassification = LossUtils.isClassification((LossType) this.treeModel.meta.get(LossUtils.LOSS_TYPE));
        } else {
            // Without a stored loss type, fall back to ALGO_TYPE, where 1 is treated as classification.
            this.isClassification = ((Integer) this.treeModel.meta.get(BaseGbdtTrainBatchOp.ALGO_TYPE)).intValue() == 1;
        }
        if (this.vectorColIndex < 0) {
            this.inputBufferThreadLocal = ThreadLocal.withInitial(() -> new Row(((String[]) this.ioSchema.f0).length));
        }
    }

    /** Returns only the predicted label/value; the detail map is not serialized on this path. */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictRaw(slicedSelectedSample).f0;
    }

    /**
     * Returns the prediction together with its detail. For classification the detail is a JSON map
     * from label string to probability; for regression it is null.
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        Tuple2<Object, Map<String, Double>> result = predictRaw(slicedSelectedSample);
        return new Tuple2<>(result.f0, result.f1 == null ? null : JsonConverter.toJson(result.f1));
    }

    /** Dispatches to the vector or table prediction path and returns the un-serialized result. */
    private Tuple2<Object, Map<String, Double>> predictRaw(Mapper.SlicedSelectedSample sample) throws Exception {
        if (this.vectorColIndex >= 0) {
            return predictResultDetailVector(this.treeModel.roots, VectorUtil.getVector(sample.get(this.vectorColIndex)));
        }
        Row row = this.inputBufferThreadLocal.get();
        sample.fillRow(row);
        return predictResultDetailTable(this.treeModel.roots, row);
    }

    /**
     * Walks one tree using features from {@code vector} and adds the reached leaf's counter,
     * weighted by {@code d}, into {@code labelCounter}.
     *
     * @throws IllegalArgumentException if a feature is missing and the node has no single missing split
     * @throws IllegalStateException    if a categorical split is encountered
     */
    private void predictVector(Vector vector, Node node, LabelCounter labelCounter, double d) {
        if (node.isLeaf()) {
            labelCounter.add(node.getCounter(), d);
            return;
        }
        double featureValue = vector.get(node.getFeatureIndex());
        if (Preprocessing.isMissing(featureValue, this.zeroAsMissing)) {
            if (node.getMissingSplit() == null || node.getMissingSplit().length != 1) {
                throw new IllegalArgumentException("When the value is missing, there must be missing split.");
            }
            predictVector(vector, node.getNextNodes()[node.getMissingSplit()[0]], labelCounter, d);
            return;
        }
        if (node.getCategoricalSplit() != null) {
            throw new IllegalStateException("Unsupported categorical feature now.");
        }
        // Continuous split: values <= threshold go left (child 0), others go right (child 1).
        Node next = featureValue <= node.getContinuousSplit() ? node.getNextNodes()[0] : node.getNextNodes()[1];
        predictVector(vector, next, labelCounter, d);
    }

    /**
     * Converts the accumulated tree output into the final prediction.
     *
     * @return for classification, the chosen label plus a map of label string to probability; for
     *         regression, the score as a {@code Double} and a null detail map
     */
    private Tuple2<Object, Map<String, Double>> predictResultDetailWithLabelCounter(LabelCounter labelCounter) {
        double score = labelCounter.getDistributions()[0] + this.period;
        if (!this.isClassification) {
            return new Tuple2<>(Double.valueOf(score), null);
        }
        double probability = 1.0d / (1.0d + Math.exp(-score));
        // Labels may be any object type, so the prediction must be typed Object, not Double.
        Object label = probability >= 0.5d ? this.treeModel.labels[1] : this.treeModel.labels[0];
        Map<String, Double> detail = new HashMap<>();
        detail.put(this.treeModel.labels[0].toString(), 1.0d - probability);
        detail.put(this.treeModel.labels[1].toString(), probability);
        return new Tuple2<>(label, detail);
    }

    /**
     * Creates an empty accumulator sized like the counter of the first non-null root.
     *
     * @return the accumulator, or null when there is no non-null root
     */
    private static LabelCounter newAccumulator(Node[] roots) {
        for (Node root : roots) {
            if (root != null) {
                return new LabelCounter(Criteria.INVALID_GAIN, 0, new double[root.getCounter().getDistributions().length]);
            }
        }
        return null;
    }

    /** Scores a feature vector against every non-null tree root. */
    private Tuple2<Object, Map<String, Double>> predictResultDetailVector(Node[] nodeArr, Vector vector) {
        LabelCounter accumulator = newAccumulator(nodeArr);
        if (accumulator == null) {
            return new Tuple2<>(null, null);
        }
        for (Node root : nodeArr) {
            if (root != null) {
                predictVector(vector, root, accumulator, 1.0d);
            }
        }
        return predictResultDetailWithLabelCounter(accumulator);
    }

    /**
     * Scores a row of selected table columns against every non-null tree root. The row is first
     * passed to the inherited {@code transform}, presumably feature preprocessing; it is modified in place.
     */
    private Tuple2<Object, Map<String, Double>> predictResultDetailTable(Node[] nodeArr, Row row) throws Exception {
        transform(row);
        LabelCounter accumulator = newAccumulator(nodeArr);
        if (accumulator == null) {
            return new Tuple2<>(null, null);
        }
        for (Node root : nodeArr) {
            if (root != null) {
                predict(row, root, accumulator, 1.0d);
            }
        }
        return predictResultDetailWithLabelCounter(accumulator);
    }
}
