package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.MultiStringIndexerPredictBatchOp;
import com.alibaba.alink.operator.batch.dataproc.MultiStringIndexerTrainBatchOp;
import com.alibaba.alink.operator.batch.feature.QuantileDiscretizerPredictBatchOp;
import com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp;
import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.dataproc.SortUtilsNext;
import com.alibaba.alink.operator.common.feature.ContinuousRanges;
import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter;
import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelMapper;
import com.alibaba.alink.operator.common.feature.quantile.PairComparable;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.params.dataproc.HasHandleInvalid;
import com.alibaba.alink.params.dataproc.HasStringOrderTypeDefaultAsRandom;
import com.alibaba.alink.params.feature.HasDropLast;
import com.alibaba.alink.params.feature.HasEncodeWithoutWoe;
import com.alibaba.alink.params.feature.QuantileDiscretizerTrainParams;
import com.alibaba.alink.params.mapper.SISOMapperParams;
import com.alibaba.alink.params.shared.colname.HasCategoricalCols;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasFeatureColsDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.shared.colname.HasOutputColDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import com.alibaba.alink.params.shared.colname.HasVectorColDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasWeightColDefaultAsNull;
import com.alibaba.alink.params.shared.tree.HasMaxBins;
import com.alibaba.alink.params.statistics.HasRoundMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing.class */
public final class Preprocessing {
    /** Logger shared by the static helpers and anonymous functions of this class. */
    private static final Logger LOG = LoggerFactory.getLogger(Preprocessing.class);

    /** Boolean parameter registered under the name {@code zeroAsMissing}; its default is {@code false}. */
    public static final ParamInfo<Boolean> ZERO_AS_MISSING = ParamInfoFactory
        .createParamInfo("zeroAsMissing", Boolean.class)
        .setHasDefaultValue(false)
        .build();

    /** Long parameter registered under the name {@code sampleCount4Bin}; its default is {@code 500000}. */
    public static final ParamInfo<Long> SAMPLE_COUNT_4_BIN = ParamInfoFactory
        .createParamInfo("sampleCount4Bin", Long.class)
        .setHasDefaultValue(500000L)
        .build();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing$MultiVector.class */
    /**
     * Map-partition function that receives {@link PairComparable} elements, sorts the elements of
     * its own partition and emits, per feature index, the elements found at the quantile positions
     * produced by {@code QuantileDiscretizerTrainBatchOp.QIndex}.
     *
     * <p>All bookkeeping is read from broadcast variables in {@link #open(Configuration)}:
     * {@code partitionedCounts}, {@code totalCounts}, {@code missingCounts},
     * {@code lessZeroCounts} and {@code nonzeroCounts}. Each count list is a list of
     * {@code (id, count)} tuples sorted by id so that it can be binary-searched.
     */
    public static class MultiVector extends RichMapPartitionFunction<PairComparable, Tuple2<Integer, Number>> {
        private static final Logger LOG = LoggerFactory.getLogger(MultiVector.class);
        private static final long serialVersionUID = 8462783213350057361L;
        // Number of quantiles requested; positions 1 .. quantileNum - 1 are emitted.
        private final int quantileNum;
        // Rounding mode handed to QIndex when it converts a quantile to an index.
        private final HasRoundMode.RoundMode roundType;
        // When true, zeros are excluded from the population the quantiles are computed over.
        private final boolean zeroAsMissing;
        // Index of this parallel subtask, captured in open().
        private transient int taskId;
        // Single value taken from the "totalCounts" broadcast variable.
        private transient long totalCounts = 0;
        // (subtask id, element count) pairs, sorted by subtask id.
        private transient List<Tuple2<Integer, Long>> partitionedCounts;
        // (feature index, count) pairs from the "missingCounts" broadcast variable, sorted by feature index.
        private transient List<Tuple2<Integer, Long>> missingCounts;
        // (feature index, count) pairs from the "lessZeroCounts" broadcast variable, sorted by feature index.
        private transient List<Tuple2<Integer, Long>> lessZeroCounts;
        // (feature index, count) pairs from the "nonzeroCounts" broadcast variable, sorted by feature index.
        private transient List<Tuple2<Integer, Long>> nonzeroCounts;
        // (feature index, running sum of the nonzero counts of all preceding features).
        private transient List<Tuple2<Integer, Long>> nonzeroOffsets;

        /**
         * @param i         number of quantiles
         * @param roundMode rounding mode passed on to {@code QIndex}
         * @param z         whether zeros are treated as missing
         */
        public MultiVector(int i, HasRoundMode.RoundMode roundMode, boolean z) {
            this.quantileNum = i;
            this.roundType = roundMode;
            this.zeroAsMissing = z;
        }

        /**
         * Loads every broadcast variable, sorts the tuple lists by their first field and derives
         * {@code nonzeroOffsets} as the prefix sums of {@code nonzeroCounts}.
         */
        public void open(Configuration configuration) throws Exception {
            this.partitionedCounts = (List) getRuntimeContext().getBroadcastVariableWithInitializer("partitionedCounts", new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.MultiVector.1
                public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> iterable) {
                    ArrayList arrayList = new ArrayList();
                    Iterator<Tuple2<Integer, Long>> it = iterable.iterator();
                    while (it.hasNext()) {
                        arrayList.add(it.next());
                    }
                    // Sort by subtask id so mapPartition can walk the list in id order.
                    arrayList.sort(Comparator.comparing(tuple2 -> {
                        return (Integer) tuple2.f0;
                    }));
                    return arrayList;
                }

                // Decompiler-renamed bridge method; it only delegates to the typed overload above.
                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m616initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Tuple2<Integer, Long>>) iterable);
                }
            });
            // The "totalCounts" broadcast set is reduced to its first element.
            this.totalCounts = ((Long) getRuntimeContext().getBroadcastVariableWithInitializer("totalCounts", new BroadcastVariableInitializer<Long, Long>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.MultiVector.2
                public Long initializeBroadcastVariable(Iterable<Long> iterable) {
                    return iterable.iterator().next();
                }

                // Decompiler-renamed bridge method; it only delegates to the typed overload above.
                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m617initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Long>) iterable);
                }
            })).longValue();
            this.missingCounts = (List) getRuntimeContext().getBroadcastVariableWithInitializer("missingCounts", new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.MultiVector.3
                public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> iterable) {
                    return (List) StreamSupport.stream(iterable.spliterator(), false).sorted(Comparator.comparing(tuple2 -> {
                        return (Integer) tuple2.f0;
                    })).collect(Collectors.toList());
                }

                // Decompiler-renamed bridge method; it only delegates to the typed overload above.
                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m618initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Tuple2<Integer, Long>>) iterable);
                }
            });
            this.lessZeroCounts = (List) getRuntimeContext().getBroadcastVariableWithInitializer("lessZeroCounts", new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.MultiVector.4
                public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> iterable) {
                    return (List) StreamSupport.stream(iterable.spliterator(), false).sorted(Comparator.comparing(tuple2 -> {
                        return (Integer) tuple2.f0;
                    })).collect(Collectors.toList());
                }

                // Decompiler-renamed bridge method; it only delegates to the typed overload above.
                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m619initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Tuple2<Integer, Long>>) iterable);
                }
            });
            this.nonzeroCounts = (List) getRuntimeContext().getBroadcastVariableWithInitializer("nonzeroCounts", new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.MultiVector.5
                public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> iterable) {
                    return (List) StreamSupport.stream(iterable.spliterator(), false).sorted(Comparator.comparing(tuple2 -> {
                        return (Integer) tuple2.f0;
                    })).collect(Collectors.toList());
                }

                // Decompiler-renamed bridge method; it only delegates to the typed overload above.
                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m620initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Tuple2<Integer, Long>>) iterable);
                }
            });
            // Prefix sums: the offset of a feature is the total nonzero count of all features before it.
            this.nonzeroOffsets = new ArrayList(this.nonzeroCounts.size());
            long j = 0;
            for (Tuple2<Integer, Long> tuple2 : this.nonzeroCounts) {
                this.nonzeroOffsets.add(Tuple2.of(tuple2.f0, Long.valueOf(j)));
                j += ((Long) tuple2.f1).longValue();
            }
            this.taskId = getRuntimeContext().getIndexOfThisSubtask();
            LOG.info("{} open.", getRuntimeContext().getTaskName());
        }

        /** Delegates to the superclass and logs the task name. */
        public void close() throws Exception {
            super.close();
            LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        /**
         * Emits the quantile elements that fall inside this subtask's slice of the global order.
         *
         * <p>The slice is {@code [start, end)}, where {@code start} is the sum of the counts of all
         * subtasks listed before this one in {@code partitionedCounts} and {@code end} adds this
         * subtask's own count. The partition's elements are buffered, sorted in natural order and
         * then walked feature by feature.
         *
         * @throws AkUnclassifiedErrorException if a subtask id larger than this subtask's id is met
         *                                      before this subtask's own entry
         * @throws Exception                    if the number of buffered elements differs from the
         *                                      count announced in {@code partitionedCounts}
         */
        public void mapPartition(Iterable<PairComparable> iterable, Collector<Tuple2<Integer, Number>> collector) throws Exception {
            // j accumulates the global start position of this subtask's slice.
            long j = 0;
            // i becomes the position of this subtask's entry inside partitionedCounts.
            int i = -1;
            int size = this.partitionedCounts.size();
            int i2 = 0;
            while (true) {
                if (i2 >= size) {
                    break;
                }
                int intValue = ((Integer) this.partitionedCounts.get(i2).f0).intValue();
                if (intValue == this.taskId) {
                    i = i2;
                    break;
                } else {
                    if (intValue > this.taskId) {
                        throw new AkUnclassifiedErrorException("Error curId: " + intValue + ". id: " + this.taskId);
                    }
                    j += ((Long) this.partitionedCounts.get(i2).f1).longValue();
                    i2++;
                }
            }
            // NOTE(review): if no entry matches taskId, i is still -1 here and get(i) throws an
            // IndexOutOfBoundsException; presumably every subtask always has an entry - confirm.
            long longValue = j + ((Long) this.partitionedCounts.get(i).f1).longValue();
            // Buffer the whole partition; the expected size is end - start.
            ArrayList arrayList = new ArrayList((int) (longValue - j));
            Iterator<PairComparable> it = iterable.iterator();
            while (it.hasNext()) {
                arrayList.add(it.next());
            }
            if (arrayList.isEmpty()) {
                return;
            }
            if (arrayList.size() != longValue - j) {
                throw new Exception("Error start end. start: " + j + ". end: " + longValue + ". size: " + arrayList.size());
            }
            LOG.info("taskId: {}, size: {}", Integer.valueOf(getRuntimeContext().getIndexOfThisSubtask()), Integer.valueOf(arrayList.size()));
            arrayList.sort(Comparator.naturalOrder());
            // Number of features touched by [start, end): if end lands exactly on a feature's
            // offset, that feature is excluded, otherwise it is included.
            int featureIndexInOffsets = (featureOffset(featureIndexInOffsets(longValue)) == longValue ? featureIndexInOffsets(longValue) : featureIndexInOffsets(longValue) + 1) - featureIndexInOffsets(j);
            // i3 is the position inside arrayList where the current feature's elements begin.
            int i3 = 0;
            for (int i4 = 0; i4 < featureIndexInOffsets; i4++) {
                int featureIndexInOffsets2 = featureIndexInOffsets(j) + i4;
                // Exclusive upper bound (within the feature) of the range held by this partition.
                int nonzeroCount = (int) nonzeroCount(featureIndexInOffsets2);
                // Inclusive lower bound (within the feature): only the first feature can start mid-way.
                int featureOffset = i4 == 0 ? (int) (j - featureOffset(featureIndexInOffsets2)) : 0;
                if (i4 == featureIndexInOffsets - 1) {
                    // The last feature is clipped at the end of this partition's slice.
                    nonzeroCount = (int) (longValue - featureOffset(featureIndexInOffsets2));
                }
                long notMissingCount = notMissingCount(featureIndexInOffsets2);
                long lessZeroCount = lessZeroCount(featureIndexInOffsets2);
                long zeroCount = zeroCount(featureIndexInOffsets2);
                if (this.zeroAsMissing) {
                    // Quantiles are taken over the non-missing, non-zero population only.
                    QuantileDiscretizerTrainBatchOp.QIndex qIndex = new QuantileDiscretizerTrainBatchOp.QIndex(notMissingCount - zeroCount, this.quantileNum, this.roundType);
                    for (int i5 = 1; i5 < this.quantileNum; i5++) {
                        long genIndex = qIndex.genIndex(i5);
                        // Emit only if the quantile position lies in the part of the feature held here.
                        if (genIndex >= featureOffset && genIndex < nonzeroCount) {
                            PairComparable pairComparable = (PairComparable) arrayList.get((int) ((genIndex + i3) - featureOffset));
                            collector.collect(Tuple2.of(pairComparable.first, pairComparable.second));
                        }
                    }
                } else {
                    // Quantiles are taken over every non-missing value, zeros included.
                    QuantileDiscretizerTrainBatchOp.QIndex qIndex2 = new QuantileDiscretizerTrainBatchOp.QIndex(notMissingCount, this.quantileNum, this.roundType);
                    for (int i6 = 1; i6 < this.quantileNum; i6++) {
                        long genIndex2 = qIndex2.genIndex(i6);
                        // NOTE(review): the zero band is tested as [lessZeroCount, zeroCount) rather
                        // than [lessZeroCount, lessZeroCount + zeroCount) - confirm this is intended.
                        if (genIndex2 < lessZeroCount || genIndex2 >= zeroCount) {
                            if (genIndex2 >= lessZeroCount) {
                                // Skip over the zeros to get a position among the stored elements.
                                genIndex2 -= zeroCount;
                            }
                            if (genIndex2 >= featureOffset && genIndex2 < nonzeroCount) {
                                PairComparable pairComparable2 = (PairComparable) arrayList.get((int) ((genIndex2 + i3) - featureOffset));
                                collector.collect(Tuple2.of(pairComparable2.first, pairComparable2.second));
                            }
                        } else if (featureOffset == 0) {
                            // The quantile falls on a zero: emit the constant instead of a stored element,
                            // and only from the partition that holds the start of the feature.
                            // Criteria.INVALID_GAIN presumably stands for 0.0 here - TODO confirm.
                            collector.collect(Tuple2.of(Integer.valueOf(featureIndexInOffsets2), Double.valueOf(Criteria.INVALID_GAIN)));
                        }
                    }
                }
                // Advance past the elements of this feature that live in this partition.
                i3 += nonzeroCount - featureOffset;
            }
        }

        /** Returns {@code totalCounts} minus the missing count of feature {@code i} (no entry counts as zero missing). */
        private long notMissingCount(int i) {
            int binarySearch = Collections.binarySearch(this.missingCounts, Tuple2.of(Integer.valueOf(i), 0L), Comparator.comparing(tuple2 -> {
                return (Integer) tuple2.f0;
            }));
            return binarySearch >= 0 ? this.totalCounts - ((Long) this.missingCounts.get(binarySearch).f1).longValue() : this.totalCounts;
        }

        /** Returns the nonzero count recorded for feature {@code i}, or 0 when the feature has no entry. */
        private long nonzeroCount(int i) {
            int binarySearch = Collections.binarySearch(this.nonzeroCounts, Tuple2.of(Integer.valueOf(i), 0L), Comparator.comparing(tuple2 -> {
                return (Integer) tuple2.f0;
            }));
            if (binarySearch >= 0) {
                return ((Long) this.nonzeroCounts.get(binarySearch).f1).longValue();
            }
            return 0L;
        }

        /** Returns {@code totalCounts} minus the nonzero count of feature {@code i}, or {@code totalCounts} when the feature has no entry. */
        private long zeroCount(int i) {
            int binarySearch = Collections.binarySearch(this.nonzeroCounts, Tuple2.of(Integer.valueOf(i), 0L), Comparator.comparing(tuple2 -> {
                return (Integer) tuple2.f0;
            }));
            return binarySearch >= 0 ? this.totalCounts - ((Long) this.nonzeroCounts.get(binarySearch).f1).longValue() : this.totalCounts;
        }

        /** Returns the count recorded in {@code lessZeroCounts} for feature {@code i}, or 0 when the feature has no entry. */
        private long lessZeroCount(int i) {
            int binarySearch = Collections.binarySearch(this.lessZeroCounts, Tuple2.of(Integer.valueOf(i), 0L), Comparator.comparing(tuple2 -> {
                return (Integer) tuple2.f0;
            }));
            if (binarySearch >= 0) {
                return ((Long) this.lessZeroCounts.get(binarySearch).f1).longValue();
            }
            return 0L;
        }

        /**
         * Returns the offset stored for feature {@code i}; when the feature has no entry, the offset
         * of the closest preceding feature is returned.
         */
        private long featureOffset(int i) {
            int binarySearch = Collections.binarySearch(this.nonzeroOffsets, Tuple2.of(Integer.valueOf(i), 0L), Comparator.comparing(tuple2 -> {
                return (Integer) tuple2.f0;
            }));
            // (-binarySearch) - 2 is the element just before the insertion point.
            // NOTE(review): this is -1 when i sorts before every entry - confirm that cannot happen.
            return binarySearch >= 0 ? ((Long) this.nonzeroOffsets.get(binarySearch).f1).longValue() : ((Long) this.nonzeroOffsets.get((-binarySearch) - 2).f1).longValue();
        }

        /**
         * Returns the feature index of an entry whose offset equals {@code j}, or, when there is no
         * exact match, of the entry with the largest offset below {@code j}.
         */
        private int featureIndexInOffsets(long j) {
            int binarySearch = Collections.binarySearch(this.nonzeroOffsets, Tuple2.of(0, Long.valueOf(j)), Comparator.comparing(tuple2 -> {
                return (Long) tuple2.f1;
            }));
            return binarySearch >= 0 ? ((Integer) this.nonzeroOffsets.get(binarySearch).f0).intValue() : ((Integer) this.nonzeroOffsets.get((-binarySearch) - 2).f0).intValue();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing$SerializeModel.class */
    public static class SerializeModel implements GroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = -3408433803135796522L;
        private final Params meta;

        public SerializeModel(Params params) {
            this.meta = params;
        }

        public void reduce(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
            HashMap hashMap = new HashMap();
            for (Row row : iterable) {
                int intValue = ((Integer) row.getField(0)).intValue();
                hashMap.put(String.valueOf(intValue), QuantileDiscretizerModelDataConverter.arraySplit2ContinuousRanges(String.valueOf(intValue), Types.DOUBLE, (Number[]) row.getField(1), ((Boolean) this.meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN)).booleanValue()));
            }
            QuantileDiscretizerModelDataConverter quantileDiscretizerModelDataConverter = new QuantileDiscretizerModelDataConverter(hashMap, this.meta);
            quantileDiscretizerModelDataConverter.save(quantileDiscretizerModelDataConverter, collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing$VectorDiscretizer.class */
    /**
     * Numeric discretizer for one vector position: maps a value to a bin index by binary search
     * over {@code bounds} and a lookup in {@code boundIndex}; missing values map to
     * {@code nullIndex}.
     */
    public static class VectorDiscretizer implements QuantileDiscretizerModelMapper.NumericQuantileDiscretizer {
        private static final long serialVersionUID = -5893784530000492957L;
        // Interval boundaries; binary search requires them to be sorted ascending.
        double[] bounds;
        // When true, a value equal to a boundary belongs to the interval on its left.
        boolean isLeftOpen;
        // Maps an interval position in bounds to the emitted bin index.
        int[] boundIndex;
        // Bin index returned for missing values.
        int nullIndex;
        // Forwarded to Preprocessing.isMissing.
        boolean zeroAsMissing;

        /**
         * @param dArr interval boundaries
         * @param z    whether intervals are left-open
         * @param iArr bin index per interval position
         * @param i    bin index used for missing values
         * @param z2   whether zero counts as missing
         */
        public VectorDiscretizer(double[] dArr, boolean z, int[] iArr, int i, boolean z2) {
            this.bounds = dArr;
            this.isLeftOpen = z;
            this.boundIndex = iArr;
            this.nullIndex = i;
            this.zeroAsMissing = z2;
        }

        /**
         * Returns the position of the interval containing {@code d}: on an exact boundary hit the
         * boundary's own position (or the previous one when left-open), otherwise the position
         * just before the insertion point.
         */
        private int findIndexInner(double d) {
            int position = Arrays.binarySearch(this.bounds, d);
            if (position < 0) {
                // Not a boundary: -(insertionPoint) - 1 was returned, step back to the lower bound.
                return (-position) - 2;
            }
            return this.isLeftOpen ? position - 1 : position;
        }

        /** Returns {@code true} unless {@code i} is the bin reserved for missing values. */
        @Override // com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelMapper.NumericQuantileDiscretizer
        public boolean isValid(int i) {
            return i != this.nullIndex;
        }

        /**
         * Returns the bin index of {@code obj}.
         *
         * @param obj the value to discretize; any {@link Number} is accepted
         * @return {@code nullIndex} when {@code Preprocessing.isMissing} reports the value as
         *         missing, otherwise the bin index of the interval containing it
         */
        @Override // com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelMapper.NumericQuantileDiscretizer
        public int findIndex(Object obj) {
            // Unbox once through Number so the missing check and the lookup agree on the value
            // and non-Double numerics are not rejected by a narrower cast.
            double value = ((Number) obj).doubleValue();
            if (Preprocessing.isMissing(value, this.zeroAsMissing)) {
                return this.nullIndex;
            }
            return this.boundIndex[findIndexInner(value)];
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing$VectorModel.class */
    /**
     * Model mapper that replaces every value of the selected vector column by the bin index
     * produced by the {@link VectorDiscretizer} registered for that vector position.
     */
    private static class VectorModel extends SISOModelMapper {
        private static final long serialVersionUID = 4501962799112695132L;
        // Discretizer per vector position; filled by loadModel.
        private final Map<Integer, VectorDiscretizer> discretizers;

        /**
         * Uses the configured vector column as the mapper's selected column.
         */
        public VectorModel(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
            super(tableSchema, tableSchema2, params.set(SISOMapperParams.SELECTED_COL, params.get(VectorPredictParams.VECTOR_COL)));
            this.discretizers = new HashMap<>();
        }

        @Override // com.alibaba.alink.common.mapper.SISOModelMapper
        protected TypeInformation<?> initPredResultColType() {
            return AlinkTypes.VECTOR;
        }

        /**
         * Discretizes the vector parsed from {@code obj}.
         *
         * <p>The bin indices are written into the arrays returned by {@code getValues()} /
         * {@code getData()} and the same vector instance is returned. For a sparse vector, a
         * position without a discretizer is set to {@code 0.0}; for a dense vector it is an error.
         *
         * @throws IllegalArgumentException if a dense vector has a position without a discretizer
         */
        @Override // com.alibaba.alink.common.mapper.SISOModelMapper
        protected Object predictResult(Object obj) throws Exception {
            Vector vector = VectorUtil.getVector(obj);
            if (vector instanceof SparseVector) {
                SparseVector sparseVector = (SparseVector) vector;
                int[] indices = sparseVector.getIndices();
                double[] values = sparseVector.getValues();
                for (int i = 0; i < indices.length; i++) {
                    VectorDiscretizer discretizer = this.discretizers.get(Integer.valueOf(indices[i]));
                    if (discretizer == null) {
                        values[i] = 0.0d;
                    } else {
                        values[i] = discretizer.findIndex(Double.valueOf(values[i]));
                    }
                }
            } else {
                double[] data = ((DenseVector) vector).getData();
                for (int i2 = 0; i2 < data.length; i2++) {
                    VectorDiscretizer discretizer = this.discretizers.get(Integer.valueOf(i2));
                    if (discretizer == null) {
                        throw new IllegalArgumentException(String.format("Can not find the discretizer for indices: %d", Integer.valueOf(i2)));
                    }
                    data[i2] = discretizer.findIndex(Double.valueOf(data[i2]));
                }
            }
            return vector;
        }

        /**
         * Loads the quantile-discretizer model and builds one {@link VectorDiscretizer} per entry;
         * the entry key is parsed as the integer vector position.
         */
        @Override // com.alibaba.alink.common.mapper.ModelMapper
        public void loadModel(List<Row> list) {
            QuantileDiscretizerModelDataConverter converter = new QuantileDiscretizerModelDataConverter();
            converter.load(list);
            // The flag is a property of the whole model, so read it once.
            boolean zeroAsMissing = converter.meta.get(Preprocessing.ZERO_AS_MISSING);
            for (Map.Entry<String, ContinuousRanges> entry : converter.data.entrySet()) {
                this.discretizers.put(Integer.valueOf(entry.getKey()), Preprocessing.createVectorDiscretizer(entry.getValue(), zeroAsMissing));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing$VectorPredict.class */
    /**
     * Batch operator that applies a trained vector-discretizer model by running
     * {@link VectorModel} as the model mapper.
     */
    public static final class VectorPredict extends ModelMapBatchOp<VectorPredict> implements VectorPredictParams<VectorPredict> {
        private static final long serialVersionUID = -9334571162681404L;

        /**
         * @param params operator parameters, passed to the parent together with the
         *               {@link VectorModel} constructor
         */
        public VectorPredict(Params params) {
            super(VectorModel::new, params);
        }

        /** Creates the operator with a {@code null} parameter object. */
        public VectorPredict() {
            this(null);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing$VectorPredictParams.class */
    /**
     * Parameter set of {@link VectorPredict}: a pure aggregation of the listed parameter
     * interfaces with no members of its own.
     *
     * @param <T> the type of the operator carrying these parameters
     */
    private interface VectorPredictParams<T> extends
        HasVectorCol<T>,
        HasReservedColsDefaultAsNull<T>,
        HasOutputColDefaultAsNull<T>,
        HasHandleInvalid<T>,
        HasEncodeWithoutWoe<T>,
        HasDropLast<T> {
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/Preprocessing$VectorTrain.class */
    /**
     * Batch operator that trains a quantile-discretizer model on a vector column and outputs it
     * in the {@link QuantileDiscretizerModelDataConverter} model schema.
     */
    public static final class VectorTrain extends BatchOperator<VectorTrain> implements QuantileDiscretizerTrainParams<VectorTrain>, HasVectorColDefaultAsNull<VectorTrain> {
        private static final Logger LOG = LoggerFactory.getLogger(VectorTrain.class);
        private static final long serialVersionUID = -1589056627883942993L;

        /** Creates the operator with a {@code null} parameter object. */
        public VectorTrain() {
            this(null);
        }

        public VectorTrain(Params params) {
            super(params);
        }

        /**
         * Selects the vector column of the first input, computes the split points with
         * {@code Preprocessing.toVectorModel} and serializes them with {@link SerializeModel}.
         *
         * <p>No explicit bridge overload is declared: a second {@code linkFrom(BatchOperator[])}
         * has the same erasure as this varargs method and the compiler generates whatever bridge
         * is needed.
         *
         * @param batchOperatorArr input operators; only the first one is used
         * @return this operator, with its output set to the serialized model
         * @throws AkIllegalArgumentException if both {@code NUM_BUCKETS} and
         *                                    {@code NUM_BUCKETS_ARRAY} are set
         */
        @Override // com.alibaba.alink.operator.batch.BatchOperator
        public VectorTrain linkFrom(BatchOperator<?>... batchOperatorArr) {
            BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);
            Params params = getParams();
            if (params.contains(QuantileDiscretizerTrainParams.NUM_BUCKETS) && params.contains(QuantileDiscretizerTrainParams.NUM_BUCKETS_ARRAY)) {
                throw new AkIllegalArgumentException("It can not set num_buckets and num_buckets_array at the same time.");
            }
            DataSet<Row> vectors = Preprocessing.select(input, getVectorCol()).getDataSet();
            int numBuckets = getNumBuckets().intValue();
            HasRoundMode.RoundMode roundMode = (HasRoundMode.RoundMode) params.get(HasRoundMode.ROUND_MODE);
            boolean zeroAsMissing = ((Boolean) params.get(Preprocessing.ZERO_AS_MISSING)).booleanValue();
            DataSet<Row> model = Preprocessing
                .toVectorModel(vectors, numBuckets, roundMode, zeroAsMissing)
                .reduceGroup(new SerializeModel(params));
            setOutput(model, new QuantileDiscretizerModelDataConverter().getModelSchema());
            return this;
        }
    }

    /**
     * Collects the distinct values of {@code dataSet} into a single sorted array.
     *
     * <p>Every value is wrapped into a {@code Tuple1}, duplicates are removed by grouping on the
     * tuple field and keeping the first element of each group, and the survivors are sorted with
     * {@code SortUtils.ComparableComparator} inside one group-reduce.
     *
     * @param dataSet label values; each one is cast to {@link Comparable}
     * @return a data set holding one {@code Object[]} with the sorted distinct labels
     */
    public static DataSet<Object[]> distinctLabels(DataSet<Object> dataSet) {
        return dataSet
            .map(new MapFunction<Object, Tuple1<Comparable>>() {
                private static final long serialVersionUID = -6913787844845900748L;

                // Must be named map to implement MapFunction.
                @Override
                public Tuple1<Comparable> map(Object obj) throws Exception {
                    return Tuple1.of((Comparable) obj);
                }
            })
            .groupBy(0)
            .reduce(new ReduceFunction<Tuple1<Comparable>>() {
                private static final long serialVersionUID = -9106561855242251475L;

                // All members of a group are equal, so any representative will do.
                @Override
                public Tuple1<Comparable> reduce(Tuple1<Comparable> tuple1, Tuple1<Comparable> tuple12) throws Exception {
                    return tuple1;
                }
            })
            .reduceGroup(new GroupReduceFunction<Tuple1<Comparable>, Object[]>() {
                private static final long serialVersionUID = -6987738021646631957L;

                @Override
                public void reduce(Iterable<Tuple1<Comparable>> iterable, Collector<Object[]> collector) throws Exception {
                    Preprocessing.LOG.info("distinctLabels start");
                    collector.collect(StreamSupport.stream(iterable.spliterator(), false)
                        .map(tuple1 -> (Comparable) tuple1.f0)
                        .sorted(new SortUtils.ComparableComparator())
                        .toArray(Object[]::new));
                    Preprocessing.LOG.info("distinctLabels end");
                }
            });
    }

    /**
     * Replaces field {@code i} of every row by the position of its value inside the broadcast
     * label array.
     *
     * <p>The incoming row object is modified and returned; no copy is made.
     *
     * @param dataSet  rows to transform
     * @param dataSet2 data set whose first element is the label array to search; it is broadcast
     *                 under the name {@code OptimVariable.model}
     * @param i        index of the label field
     * @return the mapped data set
     */
    public static DataSet<Row> findIndexOfLabel(DataSet<Row> dataSet, DataSet<Object[]> dataSet2, final int i) {
        RichMapFunction<Row, Row> labelToIndex = new RichMapFunction<Row, Row>() {
            private static final long serialVersionUID = -5365735281787702408L;
            Object[] model;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                // Only the first element of the broadcast set is kept as the label array.
                BroadcastVariableInitializer<Object[], Object[]> takeFirst = new BroadcastVariableInitializer<Object[], Object[]>() {
                    public Object[] initializeBroadcastVariable(Iterable<Object[]> iterable) {
                        return iterable.iterator().next();
                    }
                };
                this.model = (Object[]) getRuntimeContext().getBroadcastVariableWithInitializer(OptimVariable.model, takeFirst);
            }

            public Row map(Row row) throws Exception {
                Object label = row.getField(i);
                int position = Preprocessing.findIdx(this.model, label);
                row.setField(i, Integer.valueOf(position));
                return row;
            }
        };
        return dataSet.map(labelToIndex).withBroadcastSet(dataSet2, OptimVariable.model);
    }

    /**
     * Binary-searches {@code objArr} for {@code obj} using the elements' natural ordering.
     *
     * @param objArr array to search; binary search requires it to be sorted ascending
     * @param obj    value to look up
     * @return the index at which {@code obj} was found
     * @throws AkIllegalArgumentException if {@code obj} is not found
     */
    public static int findIdx(Object[] objArr, Object obj) {
        int position = Arrays.binarySearch(objArr, obj);
        if (position < 0) {
            throw new AkIllegalArgumentException("Can not find " + obj);
        }
        return position;
    }

    /**
     * Builds the data set of distinct labels.
     *
     * @param batchOperator source operator
     * @param params        parameters; {@code HasLabelCol.LABEL_COL} is read when {@code z} is false
     * @param z             when {@code true}, no labels are collected and an empty data set is
     *                      returned; when {@code false}, the label column is selected and passed
     *                      to {@link #distinctLabels(DataSet)}
     * @return a data set with the sorted distinct labels, or an empty data set
     */
    public static DataSet<Object[]> generateLabels(BatchOperator<?> batchOperator, Params params, boolean z) {
        if (z) {
            // A one-element source whose map-partition emits nothing yields an empty data set.
            return MLEnvironmentFactory.get(batchOperator.getMLEnvironmentId())
                .getExecutionEnvironment()
                .fromElements(1)
                .mapPartition(new MapPartitionFunction<Integer, Object[]>() {
                    private static final long serialVersionUID = 475582663950451641L;

                    public void mapPartition(Iterable<Integer> iterable, Collector<Object[]> collector) {
                    }
                });
        }
        DataSet<Object> labelValues = select(batchOperator, (String) params.get(HasLabelCol.LABEL_COL))
            .getDataSet()
            .map(new MapFunction<Row, Object>() {
                private static final long serialVersionUID = -3394717275759972231L;

                public Object map(Row row) {
                    // The selection has a single column, so the label is field 0.
                    return row.getField(0);
                }
            });
        return distinctLabels(labelValues);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Converts the label column of {@code batchOperator} into the representation used for training.
     *
     * <p>When {@code z} is {@code false}, each label value is replaced through {@code findIndexOfLabel}
     * using the broadcast label array {@code dataSet}, and the label column is declared as
     * {@code Types.INT}. When {@code z} is {@code true} and a label column is configured, the label
     * column is cast to {@code DOUBLE} via {@code NumericalTypeCast}. Otherwise the input is returned
     * unchanged.
     *
     * @param batchOperator input operator
     * @param params        parameters holding the label column name
     * @param dataSet       label arrays passed to {@code findIndexOfLabel}
     * @param z             selects the DOUBLE-cast branch when {@code true}
     *                      (NOTE(review): presumably "is regression" - confirm against callers)
     * @return operator with the converted label column
     */
    public static BatchOperator<?> castLabel(BatchOperator<?> batchOperator, Params params, DataSet<Object[]> dataSet, boolean z) {
        String[] colNames = batchOperator.getColNames();
        if (z) {
            if (!params.contains(HasLabelCol.LABEL_COL)) {
                return batchOperator;
            }
            return ((NumericalTypeCast) new NumericalTypeCast().setMLEnvironmentId(batchOperator.getMLEnvironmentId())).setSelectedCols((String) params.get(HasLabelCol.LABEL_COL)).setTargetType("DOUBLE").linkFrom(batchOperator);
        }
        String labelCol = (String) params.get(HasLabelCol.LABEL_COL);
        TypeInformation<?>[] colTypes = batchOperator.getColTypes();
        int labelIdx = TableUtil.findColIndex(colNames, labelCol);
        DataSet<Row> indexed = findIndexOfLabel(batchOperator.getDataSet(), dataSet, labelIdx);
        // Same schema as the input, except that the label column now carries an int index.
        TypeInformation<?>[] castedTypes = new TypeInformation<?>[colTypes.length];
        for (int col = 0; col < castedTypes.length; col++) {
            castedTypes[col] = col == labelIdx ? Types.INT : colTypes[col];
        }
        return (BatchOperator) new DataSetWrapperBatchOp(indexed, batchOperator.getColNames(), castedTypes).setMLEnvironmentId(batchOperator.getMLEnvironmentId());
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the string-indexer model for the configured categorical columns.
     *
     * <p>If categorical columns are configured and non-empty, a {@code MultiStringIndexerTrainBatchOp}
     * with {@code ALPHABET_ASC} ordering is trained on {@code batchOperator}. Otherwise an operator is
     * returned that wraps a data set emitting no rows, carrying the model schema of
     * {@code MultiStringIndexerModelDataConverter}.
     *
     * @param batchOperator training input
     * @param params        parameters that may contain {@link HasCategoricalCols#CATEGORICAL_COLS}
     * @return the trained model operator, or a row-less placeholder with the model schema
     */
    public static BatchOperator<?> generateStringIndexerModel(BatchOperator<?> batchOperator, Params params) {
        String[] categoricalCols = params.contains(HasCategoricalCols.CATEGORICAL_COLS)
            ? (String[]) params.get(HasCategoricalCols.CATEGORICAL_COLS)
            : null;
        if (categoricalCols != null && categoricalCols.length != 0) {
            return ((MultiStringIndexerTrainBatchOp) new MultiStringIndexerTrainBatchOp().setMLEnvironmentId(batchOperator.getMLEnvironmentId())).setSelectedCols(categoricalCols).setStringOrderType(HasStringOrderTypeDefaultAsRandom.StringOrderType.ALPHABET_ASC).linkFrom(batchOperator);
        }
        MultiStringIndexerModelDataConverter converter = new MultiStringIndexerModelDataConverter();
        return (BatchOperator) new DataSetWrapperBatchOp(MLEnvironmentFactory.get(batchOperator.getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Integer[]{1}).mapPartition(new MapPartitionFunction<Integer, Row>() {
            private static final long serialVersionUID = -7481931851291494026L;

            public void mapPartition(Iterable<Integer> iterable, Collector<Row> collector) throws Exception {
                // Intentionally emits nothing: the placeholder model has no rows.
            }
        }), converter.getModelSchema().getFieldNames(), converter.getModelSchema().getFieldTypes()).setMLEnvironmentId(batchOperator.getMLEnvironmentId());
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Replaces the categorical columns of {@code batchOperator} by their string-indexer indices.
     *
     * <p>The indexer model {@code batchOperator2} is applied with handle-invalid {@code "skip"}, the
     * result is projected back to the input's column names, and the categorical columns are cast to
     * {@code INT}. If no categorical columns are configured, the input is returned untouched.
     *
     * @param batchOperator  data to transform
     * @param batchOperator2 string-indexer model operator
     * @param params         parameters holding {@link HasCategoricalCols#CATEGORICAL_COLS}
     * @return operator with indexed, INT-typed categorical columns
     */
    public static BatchOperator<?> castCategoricalCols(BatchOperator<?> batchOperator, BatchOperator<?> batchOperator2, Params params) {
        String[] categoricalCols = (String[]) params.get(HasCategoricalCols.CATEGORICAL_COLS);
        if (categoricalCols == null || categoricalCols.length == 0) {
            return batchOperator;
        }
        BatchOperator<?> indexed = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp().setMLEnvironmentId(batchOperator.getMLEnvironmentId()))
            .setHandleInvalid("skip")
            .setSelectedCols(categoricalCols)
            .setReservedCols(batchOperator.getColNames())
            .linkFrom(batchOperator2, batchOperator);
        BatchOperator<?> projected = select(indexed, batchOperator.getColNames());
        return ((NumericalTypeCast) new NumericalTypeCast().setMLEnvironmentId(projected.getMLEnvironmentId()))
            .setSelectedCols(categoricalCols)
            .setTargetType("INT")
            .linkFrom(projected);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Casts every continuous feature column to {@code DOUBLE}.
     *
     * <p>The continuous columns are the feature columns minus the categorical columns (when the
     * latter are present in {@code params}). If that set is null or empty, the input is returned
     * unchanged.
     *
     * @param batchOperator data to transform
     * @param params        parameters holding feature and (optionally) categorical column names
     * @return operator whose continuous feature columns are DOUBLE-typed
     */
    public static BatchOperator<?> castContinuousCols(BatchOperator<?> batchOperator, Params params) {
        String[] continuousCols;
        if (params.contains(HasCategoricalCols.CATEGORICAL_COLS)) {
            continuousCols = (String[]) ArrayUtils.removeElements((Object[]) params.get(HasFeatureCols.FEATURE_COLS), (Object[]) params.get(HasCategoricalCols.CATEGORICAL_COLS));
        } else {
            continuousCols = (String[]) params.get(HasFeatureCols.FEATURE_COLS);
        }
        if (continuousCols == null || continuousCols.length <= 0) {
            return batchOperator;
        }
        return ((NumericalTypeCast) new NumericalTypeCast().setMLEnvironmentId(batchOperator.getMLEnvironmentId()))
            .setSelectedCols(continuousCols)
            .setTargetType("DOUBLE")
            .linkFrom(batchOperator);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Casts the weight column, if one is configured, to {@code DOUBLE}.
     *
     * @param batchOperator data to transform
     * @param params        parameters holding {@link HasWeightColDefaultAsNull#WEIGHT_COL}
     * @return the input itself when no weight column is set, otherwise the cast result
     */
    public static BatchOperator<?> castWeightCol(BatchOperator<?> batchOperator, Params params) {
        String weightCol = (String) params.get(HasWeightColDefaultAsNull.WEIGHT_COL);
        if (weightCol == null) {
            return batchOperator;
        }
        return ((NumericalTypeCast) new NumericalTypeCast().setMLEnvironmentId(batchOperator.getMLEnvironmentId()))
            .setSelectedCols(weightCol)
            .setTargetType("DOUBLE")
            .linkFrom(batchOperator);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Trains the quantile-discretizer model used to bin continuous features.
     *
     * <ul>
     *   <li>If a vector column is configured, a sample of the input is linked to a {@code VectorTrain}
     *       operator on that column.</li>
     *   <li>Otherwise, if there are continuous (non-categorical) feature columns, a sample of the input
     *       is linked to a {@code QuantileDiscretizerTrainBatchOp} on those columns.</li>
     *   <li>Otherwise an operator wrapping a data set that emits no rows is returned, carrying the
     *       model schema of {@code QuantileDiscretizerModelDataConverter}.</li>
     * </ul>
     *
     * <p>In both training branches the number of buckets is taken from {@link HasMaxBins#MAX_BINS} and
     * the {@code ZERO_AS_MISSING} setting is forwarded to the trainer.
     *
     * @param batchOperator training input
     * @param params        training parameters
     * @return the discretizer model operator
     */
    public static BatchOperator<?> generateQuantileDiscretizerModel(BatchOperator<?> batchOperator, Params params) {
        if (params.contains(HasVectorColDefaultAsNull.VECTOR_COL)) {
            Params vectorTrainParams = new Params().set(ZERO_AS_MISSING, params.get(ZERO_AS_MISSING));
            return sample(batchOperator, params).linkTo(((VectorTrain) new VectorTrain(vectorTrainParams).setMLEnvironmentId(batchOperator.getMLEnvironmentId())).setVectorCol((String) params.get(HasVectorColDefaultAsNull.VECTOR_COL)).setNumBuckets((Integer) params.get(HasMaxBins.MAX_BINS)));
        }
        String[] continuousCols = (String[]) ArrayUtils.removeElements((Object[]) params.get(HasFeatureCols.FEATURE_COLS), (Object[]) params.get(HasCategoricalCols.CATEGORICAL_COLS));
        if (continuousCols != null && continuousCols.length > 0) {
            Params trainParams = new Params().set(ZERO_AS_MISSING, params.get(ZERO_AS_MISSING));
            return sample(batchOperator, params).linkTo(((QuantileDiscretizerTrainBatchOp) new QuantileDiscretizerTrainBatchOp(trainParams).setMLEnvironmentId(batchOperator.getMLEnvironmentId())).setSelectedCols(continuousCols).setNumBuckets((Integer) params.get(HasMaxBins.MAX_BINS)));
        }
        QuantileDiscretizerModelDataConverter converter = new QuantileDiscretizerModelDataConverter();
        return (BatchOperator) new DataSetWrapperBatchOp(MLEnvironmentFactory.get(batchOperator.getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Integer[]{1}).mapPartition(new MapPartitionFunction<Integer, Row>() {
            private static final long serialVersionUID = 2328781103352773618L;

            public void mapPartition(Iterable<Integer> iterable, Collector<Row> collector) throws Exception {
                // Intentionally emits nothing: the placeholder model has no rows.
            }
        }), converter.getModelSchema().getFieldNames(), converter.getModelSchema().getFieldTypes()).setMLEnvironmentId(batchOperator.getMLEnvironmentId());
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Applies the quantile-discretizer model {@code batchOperator2} to {@code batchOperator}.
     *
     * <p>With a vector column configured, prediction is delegated to {@code VectorPredict}. Otherwise
     * the continuous (non-categorical) feature columns are discretized with handle-invalid
     * {@code "SKIP"} and then cast to {@code INT}; if there are no such columns the input is returned
     * unchanged.
     *
     * @param batchOperator  data to transform
     * @param batchOperator2 discretizer model operator
     * @param params         parameters describing vector / feature / categorical columns
     * @return operator with discretized features
     */
    public static BatchOperator<?> castToQuantile(BatchOperator<?> batchOperator, BatchOperator<?> batchOperator2, Params params) {
        if (params.contains(HasVectorColDefaultAsNull.VECTOR_COL)) {
            return ((VectorPredict) new VectorPredict().setMLEnvironmentId(batchOperator.getMLEnvironmentId())).setVectorCol((String) params.get(HasVectorColDefaultAsNull.VECTOR_COL)).linkFrom(batchOperator2, batchOperator);
        }
        String[] continuousCols = (String[]) ArrayUtils.removeElements((Object[]) params.get(HasFeatureCols.FEATURE_COLS), (Object[]) params.get(HasCategoricalCols.CATEGORICAL_COLS));
        if (continuousCols == null || continuousCols.length <= 0) {
            return batchOperator;
        }
        QuantileDiscretizerPredictBatchOp discretized = ((QuantileDiscretizerPredictBatchOp) new QuantileDiscretizerPredictBatchOp().setMLEnvironmentId(batchOperator.getMLEnvironmentId()))
            .setSelectedCols(continuousCols)
            .setHandleInvalid("SKIP")
            .linkFrom(batchOperator2, batchOperator);
        return ((NumericalTypeCast) new NumericalTypeCast().setMLEnvironmentId(discretized.getMLEnvironmentId()))
            .setSelectedCols(continuousCols)
            .setTargetType("INT")
            .linkFrom(discretized);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Randomly down-samples the rows of {@code batchOperator} so that roughly
     * {@code SAMPLE_COUNT_4_BIN} rows remain.
     *
     * <p>The total row count is computed with {@code DataSetUtils.countElementsPerPartition} and
     * broadcast as {@code "totalCount"}. Each row is then kept independently with probability
     * {@code min(sampleCount / totalCount, 1.0)}.
     *
     * <p>The ratio is computed in floating point. Dividing the two {@code long} values directly would
     * truncate the ratio to 0 whenever the requested sample count is smaller than the total (dropping
     * every row) and would throw {@link ArithmeticException} for an empty input.
     *
     * @param batchOperator input operator
     * @param params        parameters holding {@code SAMPLE_COUNT_4_BIN}
     * @return operator over the sampled rows, with the same column names and types as the input
     */
    public static BatchOperator<?> sample(BatchOperator<?> batchOperator, Params params) {
        DataSet<Row> dataSet = batchOperator.getDataSet();
        // Sum of per-partition counts; f1 of the summed tuple is the total number of rows.
        MapOperator totalCount = DataSetUtils.countElementsPerPartition(dataSet).sum(1).map(new MapFunction<Tuple2<Integer, Long>, Long>() {
            private static final long serialVersionUID = -8942137921419703888L;

            public Long map(Tuple2<Integer, Long> tuple2) throws Exception {
                return (Long) tuple2.f1;
            }
        });
        final long seed = System.currentTimeMillis();
        final long sampleCount = ((Long) params.get(SAMPLE_COUNT_4_BIN)).longValue();
        return (BatchOperator) new DataSetWrapperBatchOp(dataSet.mapPartition(new RichMapPartitionFunction<Row, Row>() {
            private static final long serialVersionUID = -68193902220003941L;
            double ratio;
            Random random;

            public void open(Configuration configuration) throws Exception {
                long total = ((Long) getRuntimeContext().getBroadcastVariable("totalCount").get(0)).longValue();
                // Floating-point division: keeps the fractional ratio and cannot throw on total == 0.
                this.ratio = Math.min((double) sampleCount / (double) total, 1.0d);
                // Offset the seed by the subtask index so each subtask gets a different seed.
                this.random = new Random(seed + getRuntimeContext().getIndexOfThisSubtask());
                Preprocessing.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                Preprocessing.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
                for (Row row : iterable) {
                    if (this.random.nextDouble() < this.ratio) {
                        collector.collect(row);
                    }
                }
            }
        }).withBroadcastSet(totalCount, "totalCount"), batchOperator.getColNames(), batchOperator.getColTypes()).setMLEnvironmentId(batchOperator.getMLEnvironmentId());
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Projects {@code batchOperator} onto the given columns, in the given order.
     *
     * <p>Column positions are resolved once with {@code TableUtil.findColIndicesWithAssertAndHint};
     * each input row is then copied into a new row containing only the requested fields.
     *
     * @param batchOperator input operator
     * @param strArr        names of the columns to keep
     * @return operator over the projected rows, with the requested column names and their types
     */
    public static BatchOperator<?> select(BatchOperator<?> batchOperator, String... strArr) {
        final int[] sourceIndices = TableUtil.findColIndicesWithAssertAndHint(batchOperator.getColNames(), strArr);
        DataSet<Row> projected = batchOperator.getDataSet().map(new RichMapFunction<Row, Row>() {
            private static final long serialVersionUID = 9119490369706910594L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                Preprocessing.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                Preprocessing.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public Row map(Row row) throws Exception {
                final int width = sourceIndices.length;
                Row out = new Row(width);
                int target = 0;
                while (target < width) {
                    out.setField(target, row.getField(sourceIndices[target]));
                    target++;
                }
                return out;
            }
        });
        return (BatchOperator) new TableSourceBatchOp(DataSetConversionUtil.toTable(batchOperator.getMLEnvironmentId(), projected, strArr, TableUtil.findColTypesWithAssertAndHint(batchOperator.getSchema(), strArr))).setMLEnvironmentId(batchOperator.getMLEnvironmentId());
    }

    /**
     * Tells whether a boxed value counts as missing.
     *
     * @param obj value to inspect; {@code null} is always missing
     * @param z   when {@code true}, a non-null value is additionally checked numerically via
     *            {@link #isMissing(double, boolean)}; when {@code false}, only {@code null} is missing
     * @param z2  flag forwarded to the numeric check
     * @return {@code true} if the value is considered missing
     */
    public static boolean isMissing(Object obj, boolean z, boolean z2) {
        if (obj == null) {
            return true;
        }
        if (!z) {
            return false;
        }
        return isMissing(((Number) obj).doubleValue(), z2);
    }

    /**
     * Tells whether a boxed integer value counts as missing.
     *
     * @param obj value to inspect; must be {@code null} or an {@link Integer}
     * @param i   the index value that denotes "missing"
     * @return {@code true} if {@code obj} is {@code null} or equals {@code i}
     */
    public static boolean isMissing(Object obj, int i) {
        if (obj == null) {
            return true;
        }
        int value = ((Integer) obj).intValue();
        return value == i;
    }

    /**
     * Tells whether a primitive double counts as missing.
     *
     * <p>NaN is always missing. When {@code z} is set, a value equal to
     * {@code Criteria.INVALID_GAIN} is missing as well.
     * NOTE(review): other code in this class passes the zero-as-missing setting here, so that constant
     * is presumably 0.0 - confirm in {@code Criteria}.
     *
     * @param d value to inspect
     * @param z enables the comparison against {@code Criteria.INVALID_GAIN}
     * @return {@code true} if the value is considered missing
     */
    public static boolean isMissing(double d, boolean z) {
        if (Double.isNaN(d)) {
            return true;
        }
        return z && d == Criteria.INVALID_GAIN;
    }

    /**
     * Tells whether a primitive long counts as missing.
     *
     * @param j value to inspect
     * @param z when {@code true}, zero is treated as missing; when {@code false}, nothing is
     * @return {@code true} only if {@code z} is set and {@code j} is zero
     */
    public static boolean isMissing(long j, boolean z) {
        if (!z) {
            return false;
        }
        return j == 0L;
    }

    /**
     * Tells whether a primitive feature value counts as missing, depending on the feature type.
     *
     * <p>Continuous features use {@link #isMissing(double, boolean)}. Any other feature type truncates
     * the value to an int and compares it with the feature's missing index.
     *
     * @param d           feature value
     * @param featureMeta metadata supplying the feature type and missing index
     * @param z           flag forwarded to the continuous check
     * @return {@code true} if the value is considered missing
     */
    public static boolean isMissing(double d, FeatureMeta featureMeta, boolean z) {
        if (featureMeta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            return isMissing(d, z);
        }
        Integer truncated = Integer.valueOf((int) d);
        return isMissing(truncated, featureMeta.getMissingIndex());
    }

    /**
     * Tells whether a boxed feature value counts as missing, depending on the feature type.
     *
     * <p>Continuous features use {@link #isMissing(Object, boolean, boolean)} with the numeric check
     * enabled. Any other feature type compares the value with the feature's missing index.
     *
     * @param obj         feature value, possibly {@code null}
     * @param featureMeta metadata supplying the feature type and missing index
     * @param z           flag forwarded to the continuous check
     * @return {@code true} if the value is considered missing
     */
    public static boolean isMissing(Object obj, FeatureMeta featureMeta, boolean z) {
        if (featureMeta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            return isMissing(obj, true, z);
        }
        return isMissing(obj, featureMeta.getMissingIndex());
    }

    /**
     * Reports whether the parameters select vector input.
     *
     * @param params parameters to inspect
     * @return {@code true} if {@link HasVectorColDefaultAsNull#VECTOR_COL} is present in {@code params}
     */
    public static boolean isSparse(Params params) {
        final boolean hasVectorCol = params.contains(HasVectorColDefaultAsNull.VECTOR_COL);
        return hasVectorCol;
    }

    /**
     * Returns the configured feature columns, failing if none were set.
     *
     * @param params parameters to read {@link HasFeatureColsDefaultAsNull#FEATURE_COLS} from
     * @param t      object whose class name is included in the error message
     * @param <T>    type of {@code t}
     * @return the feature column names
     * @throws AkIllegalArgumentException if the feature columns are not present in {@code params}
     */
    public static <T> String[] checkAndGetOptionalFeatureCols(Params params, T t) {
        if (!params.contains(HasFeatureColsDefaultAsNull.FEATURE_COLS)) {
            throw new AkIllegalArgumentException("Could not find the feature columns. Please consider to set the feature columns on the " + t.getClass().getName() + ".");
        }
        return (String[]) params.get(HasFeatureColsDefaultAsNull.FEATURE_COLS);
    }

    /**
     * Returns the configured vector column, failing if none was set.
     *
     * @param params parameters to read {@link HasVectorColDefaultAsNull#VECTOR_COL} from
     * @param t      object whose class name is included in the error message
     * @param <T>    type of {@code t}
     * @return the vector column name
     * @throws AkIllegalArgumentException if the vector column is not present in {@code params}
     */
    public static <T> String checkAndGetOptionalVectorCols(Params params, T t) {
        if (!params.contains(HasVectorColDefaultAsNull.VECTOR_COL)) {
            throw new AkIllegalArgumentException("Could not find the vector column. Please consider to set the vector column on the " + t.getClass().getName() + ".");
        }
        return (String) params.get(HasVectorColDefaultAsNull.VECTOR_COL);
    }

    /**
     * Looks up the bin index that the value {@code Criteria.INVALID_GAIN} falls into for one feature.
     * NOTE(review): given the method name, that constant is presumably 0.0 - confirm in
     * {@code Criteria}.
     *
     * @param quantileDiscretizerModelDataConverter model holding per-feature ranges and the
     *                                              {@code ZERO_AS_MISSING} meta entry
     * @param str                                   key of the feature in the model's {@code data} map
     * @return the index returned by the feature's {@code VectorDiscretizer}
     */
    public static int zeroIndex(QuantileDiscretizerModelDataConverter quantileDiscretizerModelDataConverter, String str) {
        ContinuousRanges ranges = quantileDiscretizerModelDataConverter.data.get(str);
        boolean zeroAsMissing = ((Boolean) quantileDiscretizerModelDataConverter.meta.get(ZERO_AS_MISSING)).booleanValue();
        VectorDiscretizer discretizer = createVectorDiscretizer(ranges, zeroAsMissing);
        return discretizer.findIndex(Double.valueOf(Criteria.INVALID_GAIN));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a {@code VectorDiscretizer} from the split points of a floating-point range model.
     *
     * <p>The boundary array is {@code [-inf, split_0, ..., split_{n-1}, +inf]}. The index array is
     * the identity mapping over {@code n + 3} slots, except that slot {@code n + 1} is redirected to
     * {@code n}.
     *
     * @param continuousRanges range model supplying splits, left-open flag and interval count
     * @param z                flag passed through to the {@code VectorDiscretizer} constructor
     * @return the configured discretizer
     * @throws UnsupportedOperationException if the ranges are not floating-point
     */
    public static VectorDiscretizer createVectorDiscretizer(ContinuousRanges continuousRanges, boolean z) {
        if (!continuousRanges.isFloat()) {
            throw new UnsupportedOperationException("Unsupported now.");
        }
        final int boundaryCount = continuousRanges.splitsArray.length + 1;
        final boolean leftOpen = continuousRanges.getLeftOpen().booleanValue();
        final int intervalNum = continuousRanges.getIntervalNum();

        // Identity mapping, with the slot at boundaryCount folded back onto the previous index.
        int[] binIndices = new int[boundaryCount + 2];
        for (int slot = 0; slot < binIndices.length; slot++) {
            binIndices[slot] = slot;
        }
        binIndices[boundaryCount] = boundaryCount - 1;

        // Splits framed by the two infinities.
        double[] boundaries = new double[boundaryCount + 1];
        boundaries[0] = Double.NEGATIVE_INFINITY;
        for (int pos = 1; pos < boundaryCount; pos++) {
            boundaries[pos] = continuousRanges.splitsArray[pos - 1].doubleValue();
        }
        boundaries[boundaryCount] = Double.POSITIVE_INFINITY;

        return new VectorDiscretizer(boundaries, leftOpen, binIndices, intervalNum, z);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds the per-index model rows for vector input.
     *
     * <p>Field 0 of every input row is parsed with {@code VectorUtil.getVector}. The method assembles
     * several auxiliary data sets (total row count, per-index missing counts, per-index entry counts,
     * per-index counts of values below {@code Criteria.INVALID_GAIN}) and a sorted stream of
     * (index, value) pairs, feeds them to {@code MultiVector}, and finally groups the result by index,
     * collecting the distinct values of each index in sorted order.
     *
     * @param dataSet   rows whose field 0 holds a vector
     * @param i         passed to the {@code MultiVector} constructor
     *                  (NOTE(review): presumably the number of buckets - confirm)
     * @param roundMode passed to the {@code MultiVector} constructor
     * @param z         flag passed to {@code Preprocessing.isMissing(double, boolean)} and to
     *                  {@code MultiVector}
     * @return one row per index: {@code (Integer index, Number[] sorted distinct values)}
     */
    public static DataSet<Row> toVectorModel(DataSet<Row> dataSet, int i, HasRoundMode.RoundMode roundMode, final boolean z) {
        // Total number of rows: per-partition counts summed on field 1, then unwrapped to a Long.
        MapOperator map = DataSetUtils.countElementsPerPartition(dataSet).sum(1).map(new MapFunction<Tuple2<Integer, Long>, Long>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.12
            private static final long serialVersionUID = 2858799989301224611L;

            public Long map(Tuple2<Integer, Long> tuple2) throws Exception {
                return (Long) tuple2.f1;
            }
        });
        // Per index: number of entries for which isMissing(value, z) is true. Broadcast later as "missingCounts".
        ReduceOperator reduce = dataSet.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.14
            private static final long serialVersionUID = 1205331017481743252L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Long>> collector) {
                // Partition-local counts keyed by vector index; merged across partitions by the reduce below.
                HashMap hashMap = new HashMap();
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    Vector vector = VectorUtil.getVector(it.next().getField(0));
                    if (vector instanceof SparseVector) {
                        SparseVector sparseVector = (SparseVector) vector;
                        int[] indices = sparseVector.getIndices();
                        double[] values = sparseVector.getValues();
                        for (int i2 = 0; i2 < indices.length; i2++) {
                            if (Preprocessing.isMissing(values[i2], z)) {
                                hashMap.merge(Integer.valueOf(indices[i2]), 1L, (v0, v1) -> {
                                    return Long.sum(v0, v1);
                                });
                            }
                        }
                    } else {
                        // Anything that is not a SparseVector is treated as a DenseVector.
                        double[] data = ((DenseVector) vector).getData();
                        for (int i3 = 0; i3 < data.length; i3++) {
                            if (Preprocessing.isMissing(data[i3], z)) {
                                hashMap.merge(Integer.valueOf(i3), 1L, (v0, v1) -> {
                                    return Long.sum(v0, v1);
                                });
                            }
                        }
                    }
                }
                for (Map.Entry entry : hashMap.entrySet()) {
                    collector.collect(Tuple2.of(entry.getKey(), entry.getValue()));
                }
            }
        }).groupBy(new int[]{0}).reduce(new RichReduceFunction<Tuple2<Integer, Long>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.13
            private static final long serialVersionUID = -2194135190247682594L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> tuple2, Tuple2<Integer, Long> tuple22) {
                // Same index on both sides (grouped by field 0); add the counts.
                return Tuple2.of(tuple2.f0, Long.valueOf(((Long) tuple2.f1).longValue() + ((Long) tuple22.f1).longValue()));
            }
        });
        // Per index: number of entries present, i.e. every stored index of a sparse vector and every
        // position of a dense vector, regardless of value. Broadcast later as "nonzeroCounts".
        ReduceOperator reduce2 = dataSet.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.16
            private static final long serialVersionUID = -4514110269239741447L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Long>> collector) {
                HashMap hashMap = new HashMap();
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    Vector vector = VectorUtil.getVector(it.next().getField(0));
                    if (vector instanceof SparseVector) {
                        for (int i2 : ((SparseVector) vector).getIndices()) {
                            hashMap.merge(Integer.valueOf(i2), 1L, (v0, v1) -> {
                                return Long.sum(v0, v1);
                            });
                        }
                    } else {
                        double[] data = ((DenseVector) vector).getData();
                        for (int i3 = 0; i3 < data.length; i3++) {
                            hashMap.merge(Integer.valueOf(i3), 1L, (v0, v1) -> {
                                return Long.sum(v0, v1);
                            });
                        }
                    }
                }
                for (Map.Entry entry : hashMap.entrySet()) {
                    collector.collect(Tuple2.of(entry.getKey(), entry.getValue()));
                }
            }
        }).groupBy(new int[]{0}).reduce(new RichReduceFunction<Tuple2<Integer, Long>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.15
            private static final long serialVersionUID = 4055269324161485854L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> tuple2, Tuple2<Integer, Long> tuple22) {
                return Tuple2.of(tuple2.f0, Long.valueOf(((Long) tuple2.f1).longValue() + ((Long) tuple22.f1).longValue()));
            }
        });
        // Per index: number of non-missing entries strictly below Criteria.INVALID_GAIN.
        // Broadcast later as "lessZeroCounts" (NOTE(review): the name suggests the constant is 0.0 - confirm).
        ReduceOperator reduce3 = dataSet.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.18
            private static final long serialVersionUID = -3948602601761057786L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Long>> collector) {
                HashMap hashMap = new HashMap();
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    Vector vector = VectorUtil.getVector(it.next().getField(0));
                    if (vector instanceof SparseVector) {
                        SparseVector sparseVector = (SparseVector) vector;
                        int[] indices = sparseVector.getIndices();
                        double[] values = sparseVector.getValues();
                        for (int i2 = 0; i2 < indices.length; i2++) {
                            if (!Preprocessing.isMissing(values[i2], z) && values[i2] < Criteria.INVALID_GAIN) {
                                hashMap.merge(Integer.valueOf(indices[i2]), 1L, (v0, v1) -> {
                                    return Long.sum(v0, v1);
                                });
                            }
                        }
                    } else {
                        double[] data = ((DenseVector) vector).getData();
                        for (int i3 = 0; i3 < data.length; i3++) {
                            if (!Preprocessing.isMissing(data[i3], z) && data[i3] < Criteria.INVALID_GAIN) {
                                hashMap.merge(Integer.valueOf(i3), 1L, (v0, v1) -> {
                                    return Long.sum(v0, v1);
                                });
                            }
                        }
                    }
                }
                for (Map.Entry entry : hashMap.entrySet()) {
                    collector.collect(Tuple2.of(entry.getKey(), entry.getValue()));
                }
            }
        }).groupBy(new int[]{0}).reduce(new RichReduceFunction<Tuple2<Integer, Long>>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.17
            private static final long serialVersionUID = -2353281797919459915L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> tuple2, Tuple2<Integer, Long> tuple22) {
                return Tuple2.of(tuple2.f0, Long.valueOf(((Long) tuple2.f1).longValue() + ((Long) tuple22.f1).longValue()));
            }
        });
        // Flatten every vector entry into an (index, value) pair, with a null value for missing
        // entries, and hand the pairs to SortUtilsNext.pSort. Its f0/f1 results are used below as the
        // main input and as the "partitionedCounts" broadcast set respectively.
        Tuple2 pSort = SortUtilsNext.pSort(dataSet.mapPartition(new RichMapPartitionFunction<Row, PairComparable>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.19
            private static final long serialVersionUID = 4270417871029343805L;
            // Single buffer instance that is overwritten and re-emitted for every pair.
            PairComparable pairBuff;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
                this.pairBuff = new PairComparable();
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public void mapPartition(Iterable<Row> iterable, Collector<PairComparable> collector) {
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    Vector vector = VectorUtil.getVector(it.next().getField(0));
                    if (vector instanceof SparseVector) {
                        SparseVector sparseVector = (SparseVector) vector;
                        int[] indices = sparseVector.getIndices();
                        double[] values = sparseVector.getValues();
                        for (int i2 = 0; i2 < indices.length; i2++) {
                            this.pairBuff.first = Integer.valueOf(indices[i2]);
                            this.pairBuff.second = Preprocessing.isMissing(values[i2], z) ? null : Double.valueOf(values[i2]);
                            collector.collect(this.pairBuff);
                        }
                    } else {
                        double[] data = ((DenseVector) vector).getData();
                        for (int i3 = 0; i3 < data.length; i3++) {
                            this.pairBuff.first = Integer.valueOf(i3);
                            this.pairBuff.second = Preprocessing.isMissing(data[i3], z) ? null : Double.valueOf(data[i3]);
                            collector.collect(this.pairBuff);
                        }
                    }
                }
            }
        }));
        // Run MultiVector over the sorted pairs with all auxiliary counts broadcast, then gather the
        // emitted (index, value) tuples per index.
        return ((DataSet) pSort.f0).mapPartition(new MultiVector(i, roundMode, z)).withBroadcastSet((DataSet) pSort.f1, "partitionedCounts").withBroadcastSet(map, "totalCounts").withBroadcastSet(reduce, "missingCounts").withBroadcastSet(reduce2, "nonzeroCounts").withBroadcastSet(reduce3, "lessZeroCounts").groupBy(new int[]{0}).reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, Number>, Row>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.20
            private static final long serialVersionUID = 1349754303677527939L;

            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                VectorTrain.LOG.info("{} open.", getRuntimeContext().getTaskName());
            }

            public void close() throws Exception {
                super.close();
                VectorTrain.LOG.info("{} close.", getRuntimeContext().getTaskName());
            }

            public void reduce(Iterable<Tuple2<Integer, Number>> iterable, Collector<Row> collector) throws Exception {
                // TreeSet ordered by SortUtils.OBJECT_COMPARATOR: values that compare equal are kept once, in sorted order.
                TreeSet treeSet = new TreeSet(new Comparator<Number>() { // from class: com.alibaba.alink.operator.common.tree.Preprocessing.20.1
                    @Override // java.util.Comparator
                    public int compare(Number number, Number number2) {
                        return SortUtils.OBJECT_COMPARATOR.compare(number, number2);
                    }
                });
                // Stays -1 only if the group iterable yields no element.
                int i2 = -1;
                for (Tuple2<Integer, Number> tuple2 : iterable) {
                    i2 = ((Integer) tuple2.f0).intValue();
                    treeSet.add(tuple2.f1);
                }
                collector.collect(Row.of(new Object[]{Integer.valueOf(i2), treeSet.toArray(new Number[0])}));
            }
        });
    }
}
