package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.TreeModelDataConverter;
import com.alibaba.alink.operator.common.tree.TreeUtil;
import com.alibaba.alink.operator.common.tree.parallelcart.InitTreeObjs;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeInitObj;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeObj;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeSplit;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeStat;
import com.alibaba.alink.operator.common.tree.seriestree.DecisionTree;
import com.alibaba.alink.operator.common.tree.seriestree.DenseData;
import com.alibaba.alink.params.classification.RandomForestTrainParams;
import com.alibaba.alink.params.shared.tree.HasSeed;
import com.alibaba.alink.params.shared.tree.HasTreePartition;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.ExecutorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.FEATURE_IMPORTANCE)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.TREE_FEATURE_TYPES}), @ParamSelectColumnSpec(name = "categoricalCols", allowedTypeCollections = {TypeCollections.TREE_FEATURE_TYPES}), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "labelCol")})
/* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp.class */
public abstract class BaseRandomForestTrainBatchOp<T extends BaseRandomForestTrainBatchOp<T>> extends BatchOperator<T> {
    private static final long serialVersionUID = 5757403088524138175L;
    /**
     * Label values of the training data. Assigned in {@code linkFrom} from
     * {@code Preprocessing.generateLabels} and later broadcast under the name "labels".
     */
    protected DataSet<Object[]> labels;
    /**
     * String-indexer model operator. Assigned in {@code linkFrom} from
     * {@code Preprocessing.generateStringIndexerModel}; its data set is broadcast under the name
     * "stringIndexerModel".
     */
    protected BatchOperator<?> stringIndexerModel;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp$5, reason: invalid class name */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$5.class */
    /**
     * Lookup table that maps {@code TreeUtil.TreeType} ordinals to the small integers 1..6 used as
     * case labels by {@code getGainFromParams}.
     *
     * NOTE(review): the shape (one guarded assignment per constant, {@code NoSuchFieldError}
     * ignored) looks like the helper javac generates for a {@code switch} over an enum; confirm
     * before editing by hand.
     */
    public static /* synthetic */ class AnonymousClass5 {
        // Indexed by TreeType.ordinal(); entries that are never assigned stay 0.
        static final /* synthetic */ int[] $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType = new int[TreeUtil.TreeType.values().length];

        static {
            // Each constant is registered in its own try block so that a constant missing at
            // runtime only skips its own entry.
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.AVG.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.PARTITION.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.MSE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.GINI.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.INFOGAIN.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.INFOGAINRATIO.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$AvgPartition.class */
    public static class AvgPartition implements Partitioner<Integer> {
        private static final long serialVersionUID = -8338959787279940010L;

        public int partition(Integer num, int i) {
            return num.intValue() % i;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$CheckNullValue.class */
    public static class CheckNullValue implements MapFunction<Row, Row> {
        private static final long serialVersionUID = -2809221584231401798L;
        private String[] cols;

        public CheckNullValue(String[] strArr) {
            this.cols = strArr;
        }

        public Row map(Row row) throws Exception {
            for (int i = 0; i < row.getArity(); i++) {
                if (row.getField(i) == null) {
                    throw new AkIllegalOperatorParameterException("There should not be null value in training dataset. col: " + this.cols[i] + ", Maybe you can use {@code Imputer} to fill the missing values");
                }
            }
            return row;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$Criterion.class */
    /**
     * Stop criterion for the parallel training queue: delegates to the {@code TreeObj} stored in
     * the context under "treeObj".
     */
    public static class Criterion extends CompareCriterionFunction {
        private static final long serialVersionUID = -8249556754233088562L;

        private Criterion() {
        }

        /**
         * @return whatever {@code TreeObj.terminationCriterion()} reports for the shared tree object
         */
        @Override // com.alibaba.alink.common.comqueue.CompareCriterionFunction
        public boolean calc(ComContext comContext) {
            TreeObj sharedTree = (TreeObj) comContext.getObj("treeObj");
            return sharedTree.terminationCriterion();
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$SampleDataLimit.class */
    public static class SampleDataLimit extends RichMapPartitionFunction<Row, Tuple2<Integer, Row>> {
        private static final long serialVersionUID = -8114271430777216933L;
        private final long seed;
        private double factor;
        private final int treeNum;

        public SampleDataLimit(long j, double d, int i) {
            this.seed = j;
            this.factor = d;
            this.treeNum = i;
        }

        public void open(Configuration configuration) throws Exception {
            if (this.factor > 1.0d) {
                this.factor = Math.min(this.factor / ((Double) getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt", new BroadcastVariableInitializer<Long, Double>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.SampleDataLimit.1
                    public Double initializeBroadcastVariable(Iterable<Long> iterable) {
                        Iterator<Long> it = iterable.iterator();
                        if (it.hasNext()) {
                            return Double.valueOf(it.next().doubleValue());
                        }
                        throw new AkIllegalOperatorParameterException("Can not find total sample count of sample in training dataset if factor > 1.0");
                    }

                    /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                    public /* bridge */ /* synthetic */ Object m595initializeBroadcastVariable(Iterable iterable) {
                        return initializeBroadcastVariable((Iterable<Long>) iterable);
                    }
                })).doubleValue(), 1.0d);
            }
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Row>> collector) {
            int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int superstepNumber = getIterationRuntimeContext().getSuperstepNumber();
            Random random = new Random(this.seed + (superstepNumber * (getRuntimeContext().getIndexOfThisSubtask() + 1)));
            for (Row row : iterable) {
                for (int i = (superstepNumber - 1) * numberOfParallelSubtasks; i < this.treeNum && i < superstepNumber * numberOfParallelSubtasks; i++) {
                    if (random.nextDouble() < this.factor) {
                        collector.collect(new Tuple2(Integer.valueOf(i), row));
                    }
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$SerializeModel.class */
    public static class SerializeModel extends RichGroupReduceFunction<Tuple2<Integer, String>, Row> {
        private static final long serialVersionUID = -314826879276130037L;
        private Params params;
        private transient List<Row> stringIndexerModelSerialized;
        private transient Object[] labels;

        public SerializeModel(Params params) {
            this.params = params;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.stringIndexerModelSerialized = getRuntimeContext().getBroadcastVariable("stringIndexerModel");
            this.labels = (Object[]) getRuntimeContext().getBroadcastVariableWithInitializer("labels", new BroadcastVariableInitializer<Object[], Object[]>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.SerializeModel.1
                public Object[] initializeBroadcastVariable(Iterable<Object[]> iterable) {
                    Iterator<Object[]> it = iterable.iterator();
                    if (it.hasNext()) {
                        return it.next();
                    }
                    return null;
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m596initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Object[]>) iterable);
                }
            });
        }

        public void reduce(Iterable<Tuple2<Integer, String>> iterable, Collector<Row> collector) throws Exception {
            List<Row> saveModelWithData = TreeModelDataConverter.saveModelWithData((List) ((Map) StreamSupport.stream(iterable.spliterator(), false).collect(Collectors.groupingBy(tuple2 -> {
                return (Integer) tuple2.f0;
            }, Collectors.mapping(tuple22 -> {
                return (String) tuple22.f1;
            }, Collectors.toList())))).entrySet().stream().sorted((entry, entry2) -> {
                return ((Integer) entry.getKey()).compareTo((Integer) entry2.getKey());
            }).map(entry3 -> {
                return TreeModelDataConverter.deserializeTree((List) entry3.getValue());
            }).collect(Collectors.toList()), this.params, this.stringIndexerModelSerialized, this.labels);
            collector.getClass();
            saveModelWithData.forEach((v1) -> {
                r1.collect(v1);
            });
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$SerializeModelCompleteResultFunction.class */
    /**
     * Completion step of the parallel training queue: task 0 converts the trained trees held in
     * the context into model rows; every other task contributes nothing.
     */
    public static class SerializeModelCompleteResultFunction extends CompleteResultFunction {
        private static final long serialVersionUID = -774526299754876291L;
        private final Params meta;

        SerializeModelCompleteResultFunction(Params params) {
            this.meta = params;
        }

        /**
         * @return the serialized model rows on task 0, {@code null} on all other tasks
         */
        @Override // com.alibaba.alink.common.comqueue.CompleteResultFunction
        public List<Row> calc(ComContext comContext) {
            if (comContext.getTaskId() == 0) {
                TreeObj forest = (TreeObj) comContext.getObj("treeObj");
                List indexerRows = (List) comContext.getObj("stringIndexerModel");
                List labelHolder = (List) comContext.getObj("labels");
                // The label array is the single element of the "labels" list, when present.
                Object[] labelValues = null;
                if (labelHolder != null && !labelHolder.isEmpty()) {
                    labelValues = (Object[]) labelHolder.get(0);
                }
                return TreeModelDataConverter.saveModelWithData(forest.getRoots(), this.meta, indexerRows, labelValues);
            }
            return null;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$SeriesTrainFunction.class */
    public static class SeriesTrainFunction extends RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Integer, String>> {
        private static final Logger LOG = LoggerFactory.getLogger(SeriesTrainFunction.class);
        private static final long serialVersionUID = 8664682425332787016L;
        private static final int NUM_THREADS_POOL = 4;
        private Map<String, Integer> categoricalColsSize;
        private final Params params;
        private final String[] featureCols;
        private transient int cnt;
        private transient List<Tuple2<Integer, Node>> trees;
        private transient DenseData data;
        private transient ExecutorService executorService;

        /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$SeriesTrainFunction$IterableWrapper.class */
        private static class IterableWrapper implements Iterable<Row> {
            private final Iterable<Tuple2<Integer, Row>> iterable;
            private transient IteratorWrapper iterator;

            public IterableWrapper(Iterable<Tuple2<Integer, Row>> iterable) {
                this.iterable = iterable;
            }

            public int getTreeId() {
                return this.iterator.getTreeId();
            }

            @Override // java.lang.Iterable
            public Iterator<Row> iterator() {
                if (this.iterator == null) {
                    this.iterator = new IteratorWrapper(this.iterable.iterator());
                }
                return this.iterator;
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseRandomForestTrainBatchOp$SeriesTrainFunction$IteratorWrapper.class */
        public static class IteratorWrapper implements Iterator<Row> {
            private final Iterator<Tuple2<Integer, Row>> iterator;
            private int treeId = -1;

            public IteratorWrapper(Iterator<Tuple2<Integer, Row>> it) {
                this.iterator = it;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.iterator.hasNext();
            }

            public int getTreeId() {
                return this.treeId;
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Row next() {
                Tuple2<Integer, Row> next = this.iterator.next();
                this.treeId = ((Integer) next.f0).intValue();
                return (Row) next.f1;
            }
        }

        public SeriesTrainFunction(Params params, String[] strArr) {
            this.params = params;
            this.featureCols = strArr;
        }

        /* JADX WARN: Multi-variable type inference failed */
        public void open(Configuration configuration) throws Exception {
            this.categoricalColsSize = (Map) getRuntimeContext().getBroadcastVariableWithInitializer("stringIndexerModel", new BroadcastVariableInitializer<Row, Map<String, Integer>>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.SeriesTrainFunction.1
                public Map<String, Integer> initializeBroadcastVariable(Iterable<Row> iterable) {
                    ArrayList arrayList = new ArrayList();
                    Iterator<Row> it = iterable.iterator();
                    while (it.hasNext()) {
                        arrayList.add(it.next());
                    }
                    return TreeUtil.extractCategoricalColsSize(arrayList, (String[]) SeriesTrainFunction.this.params.get(RandomForestTrainParams.CATEGORICAL_COLS));
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m598initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Row>) iterable);
                }
            });
            if (!Criteria.isRegression((TreeUtil.TreeType) this.params.get(TreeUtil.TREE_TYPE))) {
                this.categoricalColsSize.put(this.params.get(RandomForestTrainParams.LABEL_COL), getRuntimeContext().getBroadcastVariableWithInitializer("labelSize", new BroadcastVariableInitializer<Integer, Integer>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.SeriesTrainFunction.2
                    public Integer initializeBroadcastVariable(Iterable<Integer> iterable) {
                        return iterable.iterator().next();
                    }

                    /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                    public /* bridge */ /* synthetic */ Object m599initializeBroadcastVariable(Iterable iterable) {
                        return initializeBroadcastVariable((Iterable<Integer>) iterable);
                    }
                }));
            }
            long longValue = ((Long) getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt", new BroadcastVariableInitializer<Long, Long>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.SeriesTrainFunction.3
                public Long initializeBroadcastVariable(Iterable<Long> iterable) {
                    Iterator<Long> it = iterable.iterator();
                    if (it.hasNext()) {
                        return it.next();
                    }
                    throw new AkIllegalOperatorParameterException("Can not find total sample count of sample in training dataset if factor > 1.0");
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m600initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Long>) iterable);
                }
            })).longValue();
            this.cnt = ((Double) this.params.get(RandomForestTrainParams.SUBSAMPLING_RATIO)).doubleValue() > 1.0d ? Double.valueOf(Math.min(longValue, ((Double) this.params.get(RandomForestTrainParams.SUBSAMPLING_RATIO)).doubleValue())).intValue() : Double.valueOf(((Double) this.params.get(RandomForestTrainParams.SUBSAMPLING_RATIO)).doubleValue() * longValue).intValue();
            this.executorService = new ThreadPoolExecutor(NUM_THREADS_POOL, NUM_THREADS_POOL, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue(NUM_THREADS_POOL), new BasicThreadFactory.Builder().namingPattern("random-forest-%d").daemon(true).build(), new ThreadPoolExecutor.AbortPolicy());
        }

        public void close() throws Exception {
            if (this.executorService != null) {
                ExecutorUtils.gracefulShutdown(5L, TimeUnit.SECONDS, new ExecutorService[]{this.executorService});
            }
        }

        public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Integer, String>> collector) {
            LOG.info("start the random forests training");
            int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int superstepNumber = getIterationRuntimeContext().getSuperstepNumber();
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            Params m1495clone = this.params.m1495clone();
            if (this.trees == null) {
                this.trees = new ArrayList();
            }
            if (this.data == null) {
                this.data = new DenseData(this.cnt, TreeUtil.getFeatureMeta(this.featureCols, this.categoricalColsSize), TreeUtil.getLabelMeta((String) m1495clone.get(RandomForestTrainParams.LABEL_COL), this.featureCols.length, this.categoricalColsSize));
            }
            IterableWrapper iterableWrapper = new IterableWrapper(iterable);
            this.data.readFromInstances(iterableWrapper);
            if (iterableWrapper.getTreeId() >= 0) {
                LOG.info("start the random forests training {}", Integer.valueOf(iterableWrapper.getTreeId()));
                m1495clone.set((ParamInfo<ParamInfo<Criteria.Gain>>) Criteria.Gain.GAIN, (ParamInfo<Criteria.Gain>) BaseRandomForestTrainBatchOp.getGainFromParams(m1495clone, iterableWrapper.getTreeId()));
                m1495clone.set((ParamInfo<ParamInfo<Long>>) HasSeed.SEED, (ParamInfo<Long>) Long.valueOf(((Long) m1495clone.get(HasSeed.SEED)).longValue() + (superstepNumber * (indexOfThisSubtask + 1))));
                this.trees.add(Tuple2.of(Integer.valueOf(iterableWrapper.getTreeId()), new DecisionTree(this.data, m1495clone, this.executorService).fit()));
                LOG.info("end the random forests training {}", Integer.valueOf(iterableWrapper.getTreeId()));
            }
            if (superstepNumber * numberOfParallelSubtasks >= ((Integer) m1495clone.get(RandomForestTrainParams.NUM_TREES)).intValue()) {
                for (Tuple2<Integer, Node> tuple2 : this.trees) {
                    Iterator<String> it = TreeModelDataConverter.serializeTree((Node) tuple2.f1).iterator();
                    while (it.hasNext()) {
                        collector.collect(Tuple2.of(tuple2.f0, it.next()));
                    }
                }
            } else {
                collector.collect(Tuple2.of(-1, ""));
            }
            LOG.info("end the random forests training");
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates the operator by passing the given parameters to the {@code BatchOperator}
     * constructor.
     *
     * @param params operator parameters
     */
    public BaseRandomForestTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the training pipeline on the first input operator: rewrites tree/label parameters,
     * selects and casts the training columns, trains in parallel or series mode, and publishes
     * the model as the main output with feature importance as a side output.
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    @SuppressWarnings("unchecked")
    public T linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        rewriteTreeType(getParams());
        rewriteLabelType(in.getSchema(), getParams());
        String[] featureCols = TableUtil.getOptionalFeatureCols(in.getSchema(), getParams());
        getParams().set(ModelParamName.FEATURE_TYPES, FlinkTypeConverter.getTypeString(TableUtil.findColTypesWithAssertAndHint(in.getSchema(), featureCols)));
        BatchOperator<?> selected = Preprocessing.select(in, TreeUtil.trainColNames(getParams(), featureCols));
        String[] declaredCategoricalCols = getParams().contains(RandomForestTrainParams.CATEGORICAL_COLS) ? (String[]) getParams().get(RandomForestTrainParams.CATEGORICAL_COLS) : null;
        set(RandomForestTrainParams.CATEGORICAL_COLS, TableUtil.getCategoricalCols(selected.getSchema(), featureCols, declaredCategoricalCols));
        this.labels = Preprocessing.generateLabels(selected, getParams(), Criteria.isRegression((TreeUtil.TreeType) getParams().get(TreeUtil.TREE_TYPE)));
        BatchOperator<?> labelCast = Preprocessing.castLabel(selected, getParams(), this.labels, Criteria.isRegression((TreeUtil.TreeType) getParams().get(TreeUtil.TREE_TYPE)));
        this.stringIndexerModel = Preprocessing.generateStringIndexerModel(labelCast, getParams());
        BatchOperator<?> categoricalCast = Preprocessing.castCategoricalCols(labelCast, this.stringIndexerModel, getParams());
        BatchOperator<?> continuousCast = Preprocessing.castContinuousCols(categoricalCast, getParams());
        BatchOperator<?> trainInput = Preprocessing.castWeightCol(continuousCast, getParams());
        DataSet<Row> model = ((String) getParams().get(RandomForestTrainParams.CREATE_TREE_MODE)).equalsIgnoreCase("PARALLEL") ? parallelTrain(trainInput) : seriesTrain(trainInput, featureCols);
        // Side output: feature importance reduced from the model rows.
        setSideOutputTables(new Table[]{DataSetConversionUtil.toTable(getMLEnvironmentId(), (DataSet<Row>) model.reduceGroup(new TreeModelDataConverter.FeatureImportanceReducer()), new String[]{(String) getParams().get(TreeModelDataConverter.IMPORTANCE_FIRST_COL), (String) getParams().get(TreeModelDataConverter.IMPORTANCE_SECOND_COL)}, (TypeInformation<?>[]) new TypeInformation[]{Types.STRING, Types.DOUBLE})});
        setOutput(model, new TreeModelDataConverter(FlinkTypeConverter.getFlinkType((String) getParams().get(ModelParamName.LABEL_TYPE_NAME))).getModelSchema());
        // "this" is only known to be BaseRandomForestTrainBatchOp<T>; the self-type bound makes
        // the cast to T the intended one.
        return (T) this;
    }

    /**
     * Stores the label type name in {@code params} under {@code ModelParamName.LABEL_TYPE}:
     * DOUBLE for regression tree types, otherwise the type of the label column in the schema.
     *
     * @param tableSchema schema of the training table
     * @param params      parameters to read the tree type and label column from, and to update
     */
    public static void rewriteLabelType(TableSchema tableSchema, Params params) {
        TypeInformation<?> labelType;
        if (Criteria.isRegression((TreeUtil.TreeType) params.get(TreeUtil.TREE_TYPE))) {
            labelType = Types.DOUBLE;
        } else {
            labelType = TableUtil.findColTypeWithAssertAndHint(tableSchema, (String) params.get(RandomForestTrainParams.LABEL_COL));
        }
        // Pass the ParamInfo key and the plain String value; the decompiled casts to
        // ParamInfo<ParamInfo<String>> / ParamInfo<String> were not valid conversions.
        params.set(ModelParamName.LABEL_TYPE, FlinkTypeConverter.getTypeString(labelType));
    }

    /**
     * Trains all trees with the iterative communication queue: the input is discretized with a
     * quantile model, checked for nulls, and then processed by the init / stat / all-reduce /
     * split steps until {@code Criterion} reports termination.
     *
     * @param batchOperator the preprocessed training data
     * @return the serialized model rows
     */
    private DataSet<Row> parallelTrain(BatchOperator<?> batchOperator) {
        BatchOperator<?> quantileModel = Preprocessing.generateQuantileDiscretizerModel(batchOperator, getParams());
        DataSet<Row> checkedInput = Preprocessing.castToQuantile(batchOperator, quantileModel, getParams())
            .getDataSet()
            .map(new CheckNullValue(batchOperator.getColNames()));
        Params trainParams = getParams().m1495clone();
        return new IterativeComQueue()
            .setMaxIter(Integer.MAX_VALUE)
            .initWithPartitionedData("treeInput", checkedInput)
            .initWithBroadcastData(InitTreeObjs.QUANTILE_MODEL, quantileModel.getDataSet())
            .initWithBroadcastData("stringIndexerModel", this.stringIndexerModel.getDataSet())
            .initWithBroadcastData("labels", this.labels)
            .add(new TreeInitObj(trainParams))
            .add(new TreeStat())
            .add(new AllReduce("allReduce", "allReduceCnt"))
            .add(new TreeSplit())
            .setCompareCriterionOfNode0((CompareCriterionFunction) new Criterion())
            .closeWith(new SerializeModelCompleteResultFunction(trainParams))
            .exec();
    }

    /**
     * Trains the trees inside a bulk iteration: each superstep samples rows per tree id, routes
     * them by tree id and trains them with {@code SeriesTrainFunction}; the iteration result is
     * finally reduced into model rows by {@code SerializeModel}.
     *
     * @param batchOperator the preprocessed training data
     * @param strArr        feature column names
     * @return the serialized model rows
     */
    private DataSet<Row> seriesTrain(BatchOperator<?> batchOperator, String[] strArr) {
        DataSet<Row> trainData = batchOperator.getDataSet();

        // Total number of training rows, broadcast as "totalCnt".
        DataSet<Long> totalCount = DataSetUtils.countElementsPerPartition(trainData)
            .sum(1)
            .map(new MapFunction<Tuple2<Integer, Long>, Long>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.1
                private static final long serialVersionUID = 2167540828697787410L;

                public Long map(Tuple2<Integer, Long> tuple2) throws Exception {
                    return (Long) tuple2.f1;
                }
            });

        // Number of distinct label values, broadcast as "labelSize".
        DataSet<Integer> labelSize = this.labels.map(new MapFunction<Object[], Integer>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.2
            private static final long serialVersionUID = 993622654277953634L;

            public Integer map(Object[] objArr) throws Exception {
                return Integer.valueOf(objArr.length);
            }
        });

        // The iteration starts from an empty data set; it only drives the supersteps.
        IterativeDataSet<Tuple2<Integer, String>> loop = trainData
            .mapPartition(new MapPartitionFunction<Row, Tuple2<Integer, String>>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.3
                private static final long serialVersionUID = -1747675717825084026L;

                public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, String>> collector) {
                }
            })
            .iterate(((Integer) getParams().get(RandomForestTrainParams.NUM_TREES)).intValue());

        DataSet<Tuple2<Integer, Row>> sampledByTree = trainData
            .mapPartition(new SampleDataLimit(((Long) get(HasSeed.SEED)).longValue(), ((Double) get(RandomForestTrainParams.SUBSAMPLING_RATIO)).doubleValue(), ((Integer) get(RandomForestTrainParams.NUM_TREES)).intValue()))
            .withBroadcastSet(loop, "loop")
            .withBroadcastSet(totalCount, "totalCnt")
            .partitionCustom(new AvgPartition(), 0);

        DataSet<Tuple2<Integer, String>> trained = sampledByTree
            .mapPartition(new SeriesTrainFunction(getParams(), strArr))
            .withBroadcastSet(this.stringIndexerModel.getDataSet(), "stringIndexerModel")
            .withBroadcastSet(labelSize, "labelSize")
            .withBroadcastSet(totalCount, "totalCnt");

        // Termination criterion: the tuples with a negative tree id.
        DataSet<Tuple2<Integer, String>> continueMarkers = trained.filter(new FilterFunction<Tuple2<Integer, String>>() { // from class: com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp.4
            private static final long serialVersionUID = 7877735883319723407L;

            public boolean filter(Tuple2<Integer, String> tuple2) throws Exception {
                return ((Integer) tuple2.f0).intValue() < 0;
            }
        });

        return loop.closeWith(trained, continueMarkers)
            .reduceGroup(new SerializeModel(getParams()))
            .withBroadcastSet(this.stringIndexerModel.getDataSet(), "stringIndexerModel")
            .withBroadcastSet(this.labels, "labels");
    }

    /**
     * Translates the per-criterion tree counts into a partitioned forest configuration.
     *
     * <p>When the info-gain, gini and info-gain-ratio counts sum to more than zero, sets
     * {@code NUM_TREES} to that sum, the tree type to {@code PARTITION}, and
     * {@code TREE_PARTITION} to {@code "a,b"} where {@code a} is the info-gain count and
     * {@code b} is the info-gain plus gini count. Otherwise {@code params} is left untouched.
     *
     * @param params parameters to inspect and update in place
     */
    public static void rewriteTreeType(Params params) {
        int treeCount = 0;
        StringBuilder partition = new StringBuilder();
        if (params.contains(RandomForestTrainParams.NUM_TREES_OF_INFO_GAIN)) {
            treeCount += ((Integer) params.get(RandomForestTrainParams.NUM_TREES_OF_INFO_GAIN)).intValue();
        }
        partition.append(treeCount);
        if (params.contains(RandomForestTrainParams.NUM_TREES_OF_GINI)) {
            treeCount += ((Integer) params.get(RandomForestTrainParams.NUM_TREES_OF_GINI)).intValue();
        }
        partition.append(",").append(treeCount);
        if (params.contains(RandomForestTrainParams.NUM_TREES_OF_INFO_GAIN_RATIO)) {
            treeCount += ((Integer) params.get(RandomForestTrainParams.NUM_TREES_OF_INFO_GAIN_RATIO)).intValue();
        }
        if (treeCount > 0) {
            // Keys are passed as their declared ParamInfo and values as plain objects; the
            // decompiled ParamInfo casts on both arguments were not valid conversions.
            params.set(RandomForestTrainParams.NUM_TREES, Integer.valueOf(treeCount));
            params.set(TreeUtil.TREE_TYPE, TreeUtil.TreeType.PARTITION);
            params.set(HasTreePartition.TREE_PARTITION, partition.toString());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Resolves the split criterion for one tree from the configured tree type.
     *
     * @param params training parameters holding the tree type (and, depending on it, the number
     *               of trees or the tree partition string)
     * @param i      id of the tree
     * @return the gain to use for tree {@code i}
     * @throws AkIllegalOperatorParameterException if the tree type is not one of the handled values
     */
    public static Criteria.Gain getGainFromParams(Params params, int i) {
        TreeUtil.TreeType treeType = (TreeUtil.TreeType) params.get(TreeUtil.TREE_TYPE);
        // Switch on the enum itself instead of the synthetic ordinal table, whose last case label
        // had been replaced by the unrelated constant TableUtil.DISPLAY_SIZE.
        switch (treeType) {
            case AVG:
                return getAvgGain(((Integer) params.get(RandomForestTrainParams.NUM_TREES)).intValue(), i);
            case PARTITION:
                return getIntervalGain((String) params.get(HasTreePartition.TREE_PARTITION), i);
            case MSE:
                return Criteria.Gain.MSE;
            case GINI:
                return Criteria.Gain.GINI;
            case INFOGAIN:
                return Criteria.Gain.INFOGAIN;
            case INFOGAINRATIO:
                return Criteria.Gain.INFOGAINRATIO;
            default:
                throw new AkIllegalOperatorParameterException("Could not parse the gain type from params. type: " + treeType);
        }
    }

    /**
     * Parses a {@code "a,b"} partition string and resolves the gain of tree {@code i} from the
     * two boundaries.
     *
     * @param str partition string with exactly two comma-separated integers
     * @param i   id of the tree
     */
    private static Criteria.Gain getIntervalGain(String str, int i) {
        String[] bounds = str.split(",");
        boolean hasTwoBounds = bounds.length == 2;
        AkPreconditions.checkState(hasTwoBounds, "Error format of treeType: " + str);
        int infoGainEnd = Integer.parseInt(bounds[0]);
        int giniEnd = Integer.parseInt(bounds[1]);
        return getIntervalGain(infoGainEnd, giniEnd, i);
    }

    /**
     * Maps a tree id to a gain by interval: ids below {@code i} use INFOGAIN, ids below
     * {@code i2} use GINI, all remaining ids use INFOGAINRATIO.
     *
     * @param i  exclusive upper bound of the INFOGAIN interval
     * @param i2 exclusive upper bound of the GINI interval
     * @param i3 id of the tree
     */
    private static Criteria.Gain getIntervalGain(int i, int i2, int i3) {
        if (i3 < i) {
            return Criteria.Gain.INFOGAIN;
        }
        if (i3 < i2) {
            return Criteria.Gain.GINI;
        }
        return Criteria.Gain.INFOGAINRATIO;
    }

    /**
     * Splits {@code i} trees into three consecutive groups whose sizes differ by at most one
     * (earlier groups receive the remainder) and returns the gain of the group containing
     * tree {@code i2}.
     *
     * @param i  total number of trees
     * @param i2 id of the tree
     */
    private static Criteria.Gain getAvgGain(int i, int i2) {
        int baseSize = i / 3;
        int remainder = i % 3;
        // The first group gets one extra tree when the remainder is at least 1, the second when
        // it is at least 2.
        int firstBound = remainder >= 1 ? baseSize + 1 : baseSize;
        int secondBound = firstBound + baseSize + (remainder >= 2 ? 1 : 0);
        return getIntervalGain(firstBound, secondBound, i2);
    }

    /**
     * Raw-typed overload that forwards to {@code linkFrom(BatchOperator<?>...)}.
     *
     * NOTE(review): the bridge/synthetic markers suggest this is the compiler-generated bridge
     * for the covariant return type {@code T}; it has the same erasure as the typed method, so it
     * presumably has to be dropped when this file is compiled from source — confirm.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ BatchOperator linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
