package com.alibaba.alink.operator.common.tree.parallelcart.data;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.operator.common.feature.ContinuousRanges;
import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.operator.common.tree.TreeUtil;
import com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp;
import com.alibaba.alink.operator.common.tree.parallelcart.BuildLocalSketch;
import com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import com.alibaba.alink.params.shared.colname.HasCategoricalCols;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/data/DataUtil.class */
public class DataUtil {
    /**
     * Creates the training data container for the given feature metadata.
     *
     * <p>The metadata list is sorted in place by feature index and must form a gap-free sequence
     * starting at 0; otherwise the state check fails. A {@link SparseData} is returned when
     * {@link Preprocessing#isSparse(Params)} reports sparse input, a {@link DenseData} otherwise.
     *
     * @param params training parameters
     * @param list   feature metadata; sorted in place by index as a side effect
     * @param i      forwarded unchanged to the data constructor
     * @param z      forwarded unchanged to the {@link DenseData} constructor (unused for sparse data)
     * @return the newly created data container
     */
    public static Data createData(Params params, List<FeatureMeta> list, int i, boolean z) {
        list.sort(Comparator.comparingInt(FeatureMeta::getIndex));

        int expectedIndex = 0;
        int largestIndex = -1;
        for (FeatureMeta meta : list) {
            // After sorting, the k-th meta must carry index k; anything else means a missing column.
            AkPreconditions.checkState(meta.getIndex() == expectedIndex, "There are empty columns. index: %d", Integer.valueOf(meta.getIndex()));
            expectedIndex++;
            largestIndex = Math.max(largestIndex, meta.getIndex());
        }

        int featureCount = largestIndex + 1;
        if (Preprocessing.isSparse(params)) {
            return new SparseData(params, list.toArray(new FeatureMeta[0]), i, featureCount);
        }
        return new DenseData(params, list.toArray(new FeatureMeta[0]), i, featureCount, z);
    }

    /**
     * Decides whether a discretized feature value is routed to the left child of {@code node}.
     *
     * <p>Order of evaluation: if the node has a missing split and the value is the feature's
     * missing index, the first element of the missing split decides; otherwise a categorical
     * split (if present) is looked up at position {@code i}; otherwise the value is compared
     * with the continuous split.
     *
     * @param i           the feature value used as bin/category index
     * @param node        the tree node holding the split
     * @param featureMeta metadata of the feature the node splits on
     * @return {@code true} if the value goes to the left child
     */
    public static boolean left(int i, Node node, FeatureMeta featureMeta) {
        if (node.getMissingSplit() != null && Preprocessing.isMissing(Integer.valueOf(i), featureMeta.getMissingIndex())) {
            return node.getMissingSplit()[0] == 0;
        }
        if (node.getCategoricalSplit() != null) {
            return node.getCategoricalSplit()[i] == 0;
        }
        return ((double) i) <= node.getContinuousSplit();
    }

    /**
     * Decides whether a raw feature value is routed to the left child of {@code node}, resolving
     * the split threshold through a quantile summary.
     *
     * <p>If the node has a missing split and the value is considered missing, the first element
     * of the missing split decides. Otherwise the node's continuous split is used as an index into
     * {@code wQSummary.entries} and the value is compared with that entry's {@code value}.
     *
     * @param d           the raw feature value
     * @param node        the tree node holding the split
     * @param wQSummary   summary whose entries provide the split thresholds
     * @param featureMeta metadata of the feature the node splits on
     * @param z           forwarded unchanged to {@code Preprocessing.isMissing}
     * @return {@code true} if the value goes to the left child
     */
    public static boolean leftUseSummary(double d, Node node, EpsilonApproQuantile.WQSummary wQSummary, FeatureMeta featureMeta, boolean z) {
        if (node.getMissingSplit() != null && Preprocessing.isMissing(d, featureMeta, z)) {
            return node.getMissingSplit()[0] == 0;
        }
        int thresholdEntry = (int) node.getContinuousSplit();
        return d <= wQSummary.entries.get(thresholdEntry).value;
    }

    /**
     * Returns the categorical size of a feature, optionally enlarged by one extra slot.
     *
     * @param featureMeta metadata of the feature
     * @param z           when {@code true}, one is added to the feature's categorical count
     * @return {@code featureMeta.getNumCategorical()}, plus one if {@code z} is set
     */
    public static int getFeatureCategoricalSize(FeatureMeta featureMeta, boolean z) {
        if (z) {
            return featureMeta.getNumCategorical() + 1;
        }
        return featureMeta.getNumCategorical();
    }

    /**
     * Builds the metadata of continuous features from a serialized quantile-discretizer model.
     *
     * <p>All model rows are gathered in one group and loaded through a
     * {@link QuantileDiscretizerModelDataConverter}; one {@code CONTINUOUS} {@link FeatureMeta} is
     * emitted per entry of the loaded model. For sparse input the entry key is parsed as the
     * feature index and the zero index comes from {@code Preprocessing.zeroIndex}; for dense input
     * the index is the position of the key in the configured feature columns and the zero index
     * is {@code -1}.
     *
     * @param dataSet rows of the quantile-discretizer model
     * @param params  training parameters (captured by the reduce function)
     * @return a data set with one metadata record per feature of the model
     */
    public static DataSet<FeatureMeta> createContinuousMetaFromQuantile(DataSet<Row> dataSet, final Params params) {
        return dataSet.reduceGroup(new GroupReduceFunction<Row, FeatureMeta>() {
            private static final long serialVersionUID = 2798937980102249481L;

            @Override
            public void reduce(Iterable<Row> iterable, Collector<FeatureMeta> collector) throws Exception {
                List<Row> modelRows = new ArrayList<>();
                for (Row row : iterable) {
                    modelRows.add(row);
                }

                QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter();
                model.load(modelRows);

                // Use the captured local 'params'; the function has no enclosing Params instance.
                boolean sparse = Preprocessing.isSparse(params);
                for (Map.Entry<String, ContinuousRanges> entry : model.data.entrySet()) {
                    String featureName = entry.getKey();
                    int featureIndex = sparse
                        ? Integer.parseInt(featureName)
                        : TableUtil.findColIndex((String[]) params.get(HasFeatureCols.FEATURE_COLS), featureName);
                    int featureSize = model.getFeatureSize(featureName);
                    int zeroIndex = sparse ? Preprocessing.zeroIndex(model, featureName) : -1;
                    int missingIndex = model.missingIndex(featureName);
                    collector.collect(new FeatureMeta(featureName, featureIndex, FeatureMeta.FeatureType.CONTINUOUS, featureSize, zeroIndex, missingIndex));
                }
            }
        });
    }

    /**
     * Builds the metadata of categorical features from a serialized multi-string-indexer model.
     *
     * <p>All model rows are gathered in one group and loaded through a
     * {@link MultiStringIndexerModelDataConverter}. The categorical column names are read from
     * {@link HasCategoricalCols#CATEGORICAL_COLS} when that parameter is present and non-null
     * (otherwise no columns are requested), and their sizes are obtained from
     * {@code TreeUtil.extractCategoricalColsSize}. One {@code CATEGORICAL} {@link FeatureMeta} is
     * emitted per returned column; its index is the column's position in the configured feature
     * columns, and the column size is used both as the categorical count and as the missing index.
     *
     * @param dataSet rows of the multi-string-indexer model
     * @param params  training parameters (captured by the reduce function)
     * @return a data set with one metadata record per categorical column
     */
    public static DataSet<FeatureMeta> createCategoricalMetaFromStringIndexer(DataSet<Row> dataSet, final Params params) {
        return dataSet.reduceGroup(new GroupReduceFunction<Row, FeatureMeta>() {
            private static final long serialVersionUID = 8374084070326048116L;

            @Override
            public void reduce(Iterable<Row> iterable, Collector<FeatureMeta> collector) throws Exception {
                List<Row> modelRows = new ArrayList<>();
                for (Row row : iterable) {
                    modelRows.add(row);
                }
                MultiStringIndexerModelData model = new MultiStringIndexerModelDataConverter().load(modelRows);

                // Use the captured local 'params'; the function has no enclosing Params instance.
                List<String> categoricalCols = new ArrayList<>();
                if (params.contains(HasCategoricalCols.CATEGORICAL_COLS) && params.get(HasCategoricalCols.CATEGORICAL_COLS) != null) {
                    categoricalCols.addAll(Arrays.asList((String[]) params.get(HasCategoricalCols.CATEGORICAL_COLS)));
                }

                Map<String, Integer> colSizes = TreeUtil.extractCategoricalColsSize(model, categoricalCols.toArray(new String[0]));
                for (Map.Entry<String, Integer> entry : colSizes.entrySet()) {
                    String colName = entry.getKey();
                    int colSize = entry.getValue().intValue();
                    int colIndex = TableUtil.findColIndex((String[]) params.get(HasFeatureCols.FEATURE_COLS), colName);
                    collector.collect(new FeatureMeta(colName, colIndex, FeatureMeta.FeatureType.CATEGORICAL, colSize, -1, colSize));
                }
            }
        });
    }

    /**
     * Combines categorical and continuous feature metadata into a single data set.
     *
     * @param dataSet  rows of the quantile-discretizer model (source of continuous metadata)
     * @param dataSet2 rows of the multi-string-indexer model (source of categorical metadata)
     * @param params   training parameters
     * @return the union of the categorical and the continuous metadata
     */
    public static DataSet<FeatureMeta> createFeatureMetas(DataSet<Row> dataSet, DataSet<Row> dataSet2, Params params) {
        DataSet<FeatureMeta> categoricalMetas = createCategoricalMetaFromStringIndexer(dataSet2, params);
        DataSet<FeatureMeta> continuousMetas = createContinuousMetaFromQuantile(dataSet, params);
        return categoricalMetas.union(continuousMetas);
    }

    /**
     * Builds feature metadata for vector input where each dimension is described with fixed
     * values.
     *
     * <p>The number of dimensions is taken from {@link #maxIndexOfVector}. For every dimension
     * {@code i} in {@code [0, n)} one {@code CONTINUOUS} {@link FeatureMeta} named
     * {@code String.valueOf(i)} is emitted with the constant arguments {@code 2, 0, 2}.
     *
     * @param dataSet training rows containing the vector column
     * @param params  training parameters, used to locate the vector column
     * @param strArr  column names of {@code dataSet}
     * @return a data set with one metadata record per vector dimension
     */
    public static DataSet<FeatureMeta> createOneHotFeatureMeta(DataSet<Row> dataSet, Params params, String[] strArr) {
        return maxIndexOfVector(dataSet, params, strArr).flatMap(new FlatMapFunction<Tuple1<Integer>, FeatureMeta>() {
            private static final long serialVersionUID = -3553369535923306216L;

            // No hand-written flatMap(Object, Collector): the compiler generates that bridge itself.
            @Override
            public void flatMap(Tuple1<Integer> tuple1, Collector<FeatureMeta> collector) throws Exception {
                int dimensionCount = tuple1.f0;
                for (int i = 0; i < dimensionCount; i++) {
                    collector.collect(new FeatureMeta(String.valueOf(i), i, FeatureMeta.FeatureType.CONTINUOUS, 2, 0, 2));
                }
            }
        });
    }

    /**
     * Builds feature metadata for training with the epsilon-approximate quantile sketch.
     *
     * <p>When {@link GbdtTrainParams#VECTOR_COL} is set, one {@code CONTINUOUS} metadata record
     * is emitted per vector dimension reported by {@link #maxIndexOfVector}. Otherwise the result
     * is the union of the categorical metadata (see
     * {@link #createCategoricalMetaFromStringIndexer}) and one {@code CONTINUOUS} record for every
     * feature column that is not listed as a categorical column. In both cases the size of a
     * continuous feature is {@code BuildLocalSketch.maxSize} of the configured sketch ratio and
     * sketch epsilon, and zero/missing indices are {@code -1}.
     *
     * @param dataSet  training rows (used only for vector input)
     * @param dataSet2 rows of the multi-string-indexer model (used only for table input)
     * @param params   training parameters (captured by the user functions)
     * @param strArr   column names of {@code dataSet}
     * @param j        ML environment id used to obtain the execution environment
     * @return a data set with one metadata record per feature
     */
    public static DataSet<FeatureMeta> createEpsilonApproQuantileFeatureMeta(DataSet<Row> dataSet, DataSet<Row> dataSet2, final Params params, String[] strArr, long j) {
        if (params.contains(GbdtTrainParams.VECTOR_COL)) {
            return maxIndexOfVector(dataSet, params, strArr).flatMap(new FlatMapFunction<Tuple1<Integer>, FeatureMeta>() {
                private static final long serialVersionUID = 2747512134712849290L;

                // No hand-written flatMap(Object, Collector): the compiler generates that bridge itself.
                @Override
                public void flatMap(Tuple1<Integer> tuple1, Collector<FeatureMeta> collector) {
                    for (int i = 0; i < tuple1.f0; i++) {
                        collector.collect(newSketchFeatureMeta(String.valueOf(i), i, params));
                    }
                }
            });
        }

        DataSet<FeatureMeta> categoricalMetas = createCategoricalMetaFromStringIndexer(dataSet2, params);

        // Continuous columns are the feature columns minus the categorical ones.
        String[] continuousCols = ArrayUtils.removeElements(
            (String[]) params.get(HasFeatureCols.FEATURE_COLS),
            (String[]) params.get(HasCategoricalCols.CATEGORICAL_COLS));

        DataSet<FeatureMeta> continuousMetas = MLEnvironmentFactory.get(Long.valueOf(j))
            .getExecutionEnvironment()
            .fromElements(continuousCols)
            .map(new MapFunction<String, FeatureMeta>() {
                private static final long serialVersionUID = -5825333642221837526L;

                @Override
                public FeatureMeta map(String str) {
                    int colIndex = TableUtil.findColIndex((String[]) params.get(HasFeatureCols.FEATURE_COLS), str);
                    return newSketchFeatureMeta(str, colIndex, params);
                }
            });

        return categoricalMetas.union(continuousMetas);
    }

    /** Creates a CONTINUOUS feature meta sized by the configured sketch ratio and epsilon. */
    private static FeatureMeta newSketchFeatureMeta(String name, int index, Params params) {
        int sketchSize = BuildLocalSketch.maxSize(
            ((Double) params.get(BaseGbdtTrainBatchOp.SKETCH_RATIO)).doubleValue(),
            ((Double) params.get(BaseGbdtTrainBatchOp.SKETCH_EPS)).doubleValue());
        return new FeatureMeta(name, index, FeatureMeta.FeatureType.CONTINUOUS, sketchSize, -1, -1);
    }

    /**
     * Computes the largest vector dimension found in the vector column of the input.
     *
     * <p>Each partition emits one value, starting from {@code -1}: for a {@link SparseVector}
     * whose {@code size()} is negative, the largest stored index plus one is used; for every other
     * vector its {@code size()} is used. The per-partition values are then reduced with
     * {@code max(0)}.
     *
     * @param dataSet training rows containing the vector column
     * @param params  training parameters; {@link GbdtTrainParams#VECTOR_COL} names the vector column
     * @param strArr  column names of {@code dataSet}
     * @return a data set holding the maximum dimension as a single-field tuple
     */
    public static DataSet<Tuple1<Integer>> maxIndexOfVector(DataSet<Row> dataSet, Params params, String[] strArr) {
        final int vectorColIndex = TableUtil.findColIndex(strArr, (String) params.get(GbdtTrainParams.VECTOR_COL));
        return dataSet.mapPartition(new MapPartitionFunction<Row, Tuple1<Integer>>() {
            private static final long serialVersionUID = 704094286836681792L;

            @Override
            public void mapPartition(Iterable<Row> iterable, Collector<Tuple1<Integer>> collector) throws Exception {
                int maxDimension = -1;
                for (Row row : iterable) {
                    Vector vec = VectorUtil.getVector(row.getField(vectorColIndex));
                    if (vec instanceof SparseVector && vec.size() < 0) {
                        // Negative size: fall back to the stored indices to derive the dimension.
                        for (int index : ((SparseVector) vec).getIndices()) {
                            maxDimension = Math.max(maxDimension, index + 1);
                        }
                    } else {
                        maxDimension = Math.max(maxDimension, vec.size());
                    }
                }
                collector.collect(Tuple1.of(Integer.valueOf(maxDimension)));
            }
        }).max(0);
    }
}
