package com.alibaba.alink.operator.common.tree.predictors;

import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.LabelCounter;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.operator.common.tree.TreeUtil;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/predictors/RandomForestModelMapper.class */
/**
 * Model mapper that scores rows with a random-forest model: every tree root in
 * {@code treeModel.roots} is evaluated against the row, and the per-tree results are
 * accumulated into a single {@link LabelCounter}. The counter is then normalized with
 * {@code normWithWeight()}.
 *
 * <p>For regression models the prediction is the first entry of the normalized distribution.
 * For classification models the prediction is the label with the highest score, and the detail
 * string is a JSON map from label to score.
 */
public class RandomForestModelMapper extends TreeModelMapper {
    private static final Logger LOG = LoggerFactory.getLogger(RandomForestModelMapper.class);
    private static final long serialVersionUID = 1392112308487523143L;

    /** Per-thread reusable input row, sized to the number of columns in {@code ioSchema.f0}. */
    private transient ThreadLocal<Row> inputBufferThreadLocal;

    public RandomForestModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
    }

    /**
     * Initializes the tree model from the serialized model rows and sets up the per-thread
     * input buffer.
     *
     * @param list serialized model rows, passed through to {@code init}
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        init(list);
        this.inputBufferThreadLocal = ThreadLocal.withInitial(
            () -> new Row(((String[]) this.ioSchema.f0).length));
    }

    /**
     * Predicts a single sample and returns the prediction together with optional detail.
     *
     * @param slicedSelectedSample the selected input columns of one record
     * @return {@code (prediction, detailJson)}. Both are {@code null} when the forest has no
     *     trees. {@code detailJson} is {@code null} for regression. For classification the
     *     prediction may be {@code null} when no label has a score above 0 (a warning is logged)
     * @throws Exception propagated from the feature transformation or tree traversal
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        Node[] roots = this.treeModel.roots;
        Row row = this.inputBufferThreadLocal.get();
        slicedSelectedSample.fillRow(row);
        transform(row);

        if (roots.length == 0) {
            // An empty forest yields no prediction and no detail.
            return new Tuple2<>(null, null);
        }

        // All trees accumulate into one counter sized like the first tree's distribution.
        LabelCounter labelCounter = new LabelCounter(
            Criteria.INVALID_GAIN, 0, new double[roots[0].getCounter().getDistributions().length]);
        for (Node root : roots) {
            predict(row, root, labelCounter, 1.0d);
        }
        labelCounter.normWithWeight();

        if (Criteria.isRegression((TreeUtil.TreeType) this.treeModel.meta.get(TreeUtil.TREE_TYPE))) {
            return new Tuple2<>(Double.valueOf(labelCounter.getDistributions()[0]), null);
        }
        return predictClassificationDetail(labelCounter.getDistributions());
    }

    /**
     * Picks the highest-scoring label and serializes the full label-to-score map.
     *
     * @param distributions normalized per-label scores, index-aligned with {@code treeModel.labels}
     * @return {@code (label, detailJson)}. The label is {@code null} when no score is strictly
     *     greater than 0.0 (for example all zeros or NaN). Previously this case failed with an
     *     out-of-bounds index after logging the warning.
     */
    private Tuple2<Object, String> predictClassificationDetail(double[] distributions) {
        HashMap<String, Double> detail = new HashMap<>();
        double best = 0.0d;
        int bestIndex = -1;
        for (int i = 0; i < distributions.length; i++) {
            detail.put(String.valueOf(this.treeModel.labels[i]), Double.valueOf(distributions[i]));
            // Strict '<' keeps the first index on ties and never selects a NaN entry.
            if (best < distributions[i]) {
                best = distributions[i];
                bestIndex = i;
            }
        }

        Object label = null;
        if (bestIndex == -1) {
            LOG.warn("Can not find the probability: {}", JsonConverter.toJson(distributions));
        } else {
            label = this.treeModel.labels[bestIndex];
        }
        return new Tuple2<>(label, JsonConverter.toJson(detail));
    }

    /**
     * Predicts a single sample and returns only the prediction (no detail).
     *
     * @param slicedSelectedSample the selected input columns of one record
     * @return the prediction as produced by {@link #predictResultDetail}
     * @throws Exception propagated from {@link #predictResultDetail}
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictResultDetail(slicedSelectedSample).f0;
    }
}
