package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.shared.colname.HasWeightColDefaultAsNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/TreeUtil.class */
/**
 * Static helpers shared by the tree-based operators: resolving categorical column sizes from a
 * multi-string-indexer model, building {@link FeatureMeta} descriptors for features and labels,
 * and assembling the list of columns used for training.
 */
public class TreeUtil {
    /**
     * Parameter descriptor for the {@code "treeType"} parameter; defaults to {@link TreeType#AVG}.
     *
     * <p>Declared {@code final} so the shared descriptor cannot be reassigned at runtime.
     */
    public static final ParamInfo<TreeType> TREE_TYPE = ParamInfoFactory
        .createParamInfo("treeType", TreeType.class)
        .setDescription("The criteria of the tree. There are three options: \"AVG\", \"partition\" or \"gini(infoGain, infoGainRatio, mse)\"")
        .setHasDefaultValue(TreeType.AVG)
        .build();

    /** Allowed values of {@link #TREE_TYPE}. */
    public enum TreeType {
        AVG,
        PARTITION,
        GINI,
        INFOGAIN,
        INFOGAINRATIO,
        MSE
    }

    /**
     * Loads a multi-string-indexer model from its serialized rows and returns the number of
     * tokens recorded for each requested column.
     *
     * @param list   model rows, deserialized via {@code MultiStringIndexerModelDataConverter#load}
     * @param strArr names of the categorical columns to look up
     * @return a new mutable map from column name to its token count
     */
    public static Map<String, Integer> extractCategoricalColsSize(List<Row> list, String[] strArr) {
        return extractCategoricalColsSize(new MultiStringIndexerModelDataConverter().load(list), strArr);
    }

    /**
     * Returns the number of tokens recorded in {@code multiStringIndexerModelData} for each
     * requested column.
     *
     * @param multiStringIndexerModelData already-loaded string-indexer model
     * @param strArr                      names of the categorical columns to look up
     * @return a new mutable map from column name to its token count
     */
    public static Map<String, Integer> extractCategoricalColsSize(MultiStringIndexerModelData multiStringIndexerModelData, String[] strArr) {
        Map<String, Integer> sizes = new HashMap<>();
        for (String col : strArr) {
            // NOTE(review): narrowing cast kept from the original; a token count above
            // Integer.MAX_VALUE would wrap silently — assumed never to occur in practice.
            sizes.put(col, (int) multiStringIndexerModelData.getNumberOfTokensOfColumn(col));
        }
        return sizes;
    }

    /**
     * Builds one {@link FeatureMeta} per feature column, using the column's position in
     * {@code strArr} as its index.
     *
     * @param strArr feature column names
     * @param map    categorical column sizes; a column whose name is a key with a non-null
     *               value gets the categorical constructor, every other column the
     *               two-argument one
     * @return feature metadata, aligned with {@code strArr}
     */
    public static FeatureMeta[] getFeatureMeta(String[] strArr, Map<String, Integer> map) {
        FeatureMeta[] metas = new FeatureMeta[strArr.length];
        for (int i = 0; i < strArr.length; i++) {
            metas[i] = toFeatureMeta(strArr[i], i, map);
        }
        return metas;
    }

    /**
     * Builds the {@link FeatureMeta} for the label column.
     *
     * @param str label column name
     * @param i   index assigned to the label
     * @param map categorical column sizes (same semantics as in {@link #getFeatureMeta})
     * @return label metadata
     */
    public static FeatureMeta getLabelMeta(String str, int i, Map<String, Integer> map) {
        return toFeatureMeta(str, i, map);
    }

    /**
     * Returns the feature columns, followed by the label column (when {@code LABEL_COL} is
     * present in {@code params}) and the weight column (when it is non-null).
     *
     * @param params operator parameters
     * @param strArr feature column names
     * @return a new array of the columns needed for training
     */
    public static String[] trainColNames(Params params, String[] strArr) {
        List<String> cols = new ArrayList<>(Arrays.asList(strArr));
        if (params.contains(HasLabelCol.LABEL_COL)) {
            cols.add(params.get(HasLabelCol.LABEL_COL));
        }
        String weightCol = params.get(HasWeightColDefaultAsNull.WEIGHT_COL);
        if (weightCol != null) {
            cols.add(weightCol);
        }
        return cols.toArray(new String[0]);
    }

    /**
     * Creates categorical metadata when {@code categoricalSizes} holds a size for {@code name},
     * otherwise the two-argument (non-categorical) metadata. Uses a single map lookup, so a key
     * mapped to {@code null} no longer causes a NullPointerException.
     */
    private static FeatureMeta toFeatureMeta(String name, int index, Map<String, Integer> categoricalSizes) {
        Integer size = categoricalSizes.get(name);
        return size != null
            ? new FeatureMeta(name, index, size.intValue())
            : new FeatureMeta(name, index);
    }
}
