package com.alibaba.alink.operator.common.tree.paralleltree;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.TreeUtil;
import com.alibaba.alink.operator.common.tree.parallelcart.InitTreeObjs;
import com.alibaba.alink.params.classification.RandomForestTrainParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/paralleltree/TreeInitObj.class */
/**
 * {@link ComputeFunction} that, on step 1 only, builds the local {@link TreeObj} used for
 * parallel tree training from objects previously placed in the {@link ComContext}.
 *
 * <p>Context objects read: {@code "treeInput"}, {@link InitTreeObjs#QUANTILE_MODEL},
 * {@code "stringIndexerModel"} and (for non-regression trees) {@code "labels"}.
 * Context objects written: {@code "allReduce"} (the histogram buffer) and {@code "treeObj"}.
 */
public class TreeInitObj extends ComputeFunction {
    private static final long serialVersionUID = 1809146149000002401L;
    private final Params params;

    /**
     * @param params training parameters; a copy of them, extended with task-local values,
     *               is handed to the created tree object
     */
    public TreeInitObj(Params params) {
        this.params = params;
    }

    /**
     * Loads a quantile-discretizer model from its row representation.
     *
     * @param list model rows; may be {@code null} or empty
     * @return the loaded converter, or {@code null} when there are no model rows
     */
    private static QuantileDiscretizerModelDataConverter initialMapping(List<Row> list) {
        // Treat a missing context object the same as an empty model, mirroring how a missing
        // "treeInput" is treated as zero local rows in calc().
        if (list == null || list.isEmpty()) {
            return null;
        }
        QuantileDiscretizerModelDataConverter converter = new QuantileDiscretizerModelDataConverter();
        converter.load(list);
        return converter;
    }

    /**
     * On the first step, builds a {@link RegObj} (regression) or {@link ClassifierObj}
     * (otherwise) from this task's partition of the training data and stores it in the
     * context under {@code "treeObj"}. Does nothing on any other step.
     *
     * <p>Each input row is expected to hold the already-discretized feature values as
     * {@code Integer} in columns {@code [0, nFeatures)} followed by the label in column
     * {@code nFeatures} ({@code Double} for regression, {@code Integer} otherwise).
     *
     * @throws IllegalStateException if a non-regression tree is trained but the
     *                               {@code "labels"} context object is missing or empty
     * @throws ArithmeticException   if {@code nFeatures * nLocalRow} overflows an {@code int}
     */
    // TreeObj is used as a raw type because the concrete subclass (and hence the label array
    // type passed to setLabels) is only chosen at runtime.
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override
    public void calc(ComContext comContext) {
        if (comContext.getStepNo() != 1) {
            return;
        }
        List<Row> treeInput = (List<Row>) comContext.getObj("treeInput");
        List<Row> quantileModel = (List<Row>) comContext.getObj(InitTreeObjs.QUANTILE_MODEL);
        List<Row> stringIndexerModel = (List<Row>) comContext.getObj("stringIndexerModel");
        List<Object[]> labels = (List<Object[]>) comContext.getObj("labels");

        final boolean isRegression =
            Criteria.isRegression((TreeUtil.TreeType) params.get(TreeUtil.TREE_TYPE));
        final String[] featureCols = (String[]) params.get(RandomForestTrainParams.FEATURE_COLS);
        final String labelCol = (String) params.get(RandomForestTrainParams.LABEL_COL);
        final int nLocalRow = treeInput == null ? 0 : treeInput.size();
        final int nFeatures = featureCols.length;

        // Task-local copy of the parameters so the shared instance is not mutated.
        Params localParams = params.m1495clone();
        localParams.set(TreeObj.TASK_ID, comContext.getTaskId());
        localParams.set(TreeObj.NUM_OF_SUBTASKS, comContext.getNumTask());
        localParams.set(TreeObj.N_LOCAL_ROW, nLocalRow);

        QuantileDiscretizerModelDataConverter quantileMapping = initialMapping(quantileModel);

        String[] categoricalCols = (String[]) params.get(RandomForestTrainParams.CATEGORICAL_COLS);
        if (categoricalCols == null) {
            categoricalCols = new String[0];
        }
        Map<String, Integer> categoricalColsSize =
            TreeUtil.extractCategoricalColsSize(stringIndexerModel, categoricalCols);

        if (!isRegression) {
            if (labels == null || labels.isEmpty()) {
                throw new IllegalStateException(
                    "Non-regression tree training requires the \"labels\" context object, "
                        + "but it is missing or empty.");
            }
            // The label column is registered as categorical with one category per label value.
            categoricalColsSize.put(labelCol, labels.get(0).length);
        }

        FeatureMeta[] featureMetas = TreeUtil.getFeatureMeta(featureCols, categoricalColsSize);
        FeatureMeta labelMeta = TreeUtil.getLabelMeta(labelCol, nFeatures, categoricalColsSize);
        TreeObj treeObj = isRegression
            ? new RegObj(localParams, quantileMapping, featureMetas, labelMeta)
            : new ClassifierObj(localParams, quantileMapping, featureMetas, labelMeta);

        // Features are stored column-major: feature j of local row i lives at j * nLocalRow + i.
        int[] features = new int[Math.multiplyExact(nFeatures, nLocalRow)];
        double[] regressionLabels = isRegression ? new double[nLocalRow] : null;
        int[] classLabels = isRegression ? null : new int[nLocalRow];

        for (int row = 0; row < nLocalRow; row++) {
            Row record = treeInput.get(row);
            for (int col = 0; col < nFeatures; col++) {
                features[col * nLocalRow + row] = (Integer) record.getField(col);
            }
            if (isRegression) {
                regressionLabels[row] = (Double) record.getField(nFeatures);
            } else {
                classLabels[row] = (Integer) record.getField(nFeatures);
            }
        }

        treeObj.setFeatures(features);
        if (isRegression) {
            treeObj.setLabels(regressionLabels);
        } else {
            treeObj.setLabels(classLabels);
        }

        // The same buffer instance is both published to the context and attached to the tree.
        double[] hist = new double[treeObj.getMaxHistBufferSize()];
        comContext.putObj("allReduce", hist);
        treeObj.setHist(hist);
        treeObj.initialRoot();
        comContext.putObj("treeObj", treeObj);
    }
}
