package com.alibaba.alink.operator.batch.dataproc;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.dataproc.HugeStringIndexerUtil;
import com.alibaba.alink.operator.common.dataproc.StringIndexerModelDataConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.dataproc.SparseFeatureIndexerTrainParams;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.PriorityQueue;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Trains an index over sparse string features.
 *
 * <p>The selected column must be a string column. Each value holds features separated by
 * {@code spareFeatureDelimiter}. When {@code hasValue} is true, each feature is a
 * {@code key<kvValDelimiter>value} pair and only the key is indexed. Every distinct feature is counted at most
 * once per row. The features are then:
 * <ol>
 *   <li>optionally restricted to those containing one of {@code candidateTags};</li>
 *   <li>filtered by {@code minFrequency} or, if that is not positive, by {@code minPercent} of the row count;</li>
 *   <li>optionally reduced to the {@code topN} most frequent;</li>
 *   <li>indexed via {@code HugeStringIndexerUtil.indexSortedByAlphabet}.</li>
 * </ol>
 *
 * <p>Main output: a StringIndexer-format model. Subtask 0 first emits a row holding the JSON of the
 * {@code hasValue} / delimiter params. Side output 0: a {@code (feature, feature_count)} table, taken after the
 * candidate-tag filter but before the frequency and top-N filters.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.MODEL_INFO), @PortSpec(value = PortType.DATA, desc = PortDesc.FEATURE_FREQUENCY)})
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.STRING_TYPE})
@NameCn("稀疏特征编码训练")
@NameEn("Sparse Feature Indexer Train")
public class SparseFeatureIndexerTrainBatchOp extends BatchOperator<SparseFeatureIndexerTrainBatchOp> implements SparseFeatureIndexerTrainParams<SparseFeatureIndexerTrainBatchOp> {
    private static final long serialVersionUID = 6127001361372601674L;

    /** Broadcast name under which the total input row count is shared with the min-percent filter. */
    private static final String SAMPLE_COUNT_BROADCAST = "sampleCount";

    public SparseFeatureIndexerTrainBatchOp() {
        this(new Params());
    }

    public SparseFeatureIndexerTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the feature-index model from the first input.
     *
     * @param inputs upstream operators; only the first one is used
     * @return this operator, with the model as main output and the feature frequencies as side output 0
     * @throws AkIllegalDataException if the selected column is not of type string
     */
    @Override
    public SparseFeatureIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        final String selectedCol = getSelectedCol();
        final int topN = getTopN();
        final int minFrequency = getMinFrequency();
        final String featureDelimiter = getSpareFeatureDelimiter();
        final String kvValDelimiter = getKvValDelimiter();
        final boolean hasValue = getHasValue();
        final double minPercent = getMinPercent();
        final String[] candidateTags = getCandidateTags();

        BatchOperator<?> in = checkAndGetFirst(inputs);
        TypeInformation<?> colType = TableUtil.findColType(in.getSchema(), selectedCol);
        if (!Types.STRING.equals(colType)) {
            throw new AkIllegalDataException("featureColName type must be string, but input type is " + colType);
        }
        DataSet<Row> data = in.select(selectedCol).getDataSet();

        // Total number of input rows; only consumed by the min-percent filter below.
        DataSet<Long> sampleCount = DataSetUtils.countElementsPerPartition(data)
            .sum(1)
            .map(new MapFunction<Tuple2<Integer, Long>, Long>() {
                private static final long serialVersionUID = -8507632108475760763L;

                @Override
                public Long map(Tuple2<Integer, Long> value) {
                    return value.f1;
                }
            })
            .name("statics_sample_number");

        // (count, feature) pairs, where count is the number of rows containing the feature.
        DataSet<Tuple2<Integer, String>> featureCounts = data
            .flatMap(new FlatMapFunction<Row, Tuple2<Integer, String>>() {
                @Override
                public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) throws Exception {
                    String features = (String) row.getField(0);
                    if (features == null) {
                        return;
                    }
                    // Emit each feature at most once per row, so counts are per-row occurrences. The min-percent
                    // threshold is a fraction of the row count, so per-row counts are what it must compare against.
                    HashSet<String> seen = new HashSet<>();
                    // NOTE: String.split interprets both delimiters as regular expressions.
                    for (String feature : features.split(featureDelimiter)) {
                        if (feature.isEmpty()) {
                            continue;
                        }
                        String key = feature;
                        if (hasValue) {
                            String[] kv = feature.split(kvValDelimiter);
                            // Skip malformed entries. A token consisting only of delimiters splits into an empty
                            // array; more than one delimiter is ambiguous; an empty key is meaningless.
                            if (kv.length == 0 || kv.length > 2 || kv[0].isEmpty()) {
                                continue;
                            }
                            key = kv[0];
                        }
                        if (seen.add(key)) {
                            out.collect(Tuple2.of(1, key));
                        }
                    }
                }
            })
            .groupBy(1)
            .reduce(new ReduceFunction<Tuple2<Integer, String>>() {
                @Override
                public Tuple2<Integer, String> reduce(Tuple2<Integer, String> a, Tuple2<Integer, String> b) {
                    return Tuple2.of(a.f0 + b.f0, a.f1);
                }
            })
            .name("split_and_count_fea_frequency");

        if (candidateTags != null && candidateTags.length > 0) {
            featureCounts = featureCounts
                .filter(new FilterFunction<Tuple2<Integer, String>>() {
                    @Override
                    public boolean filter(Tuple2<Integer, String> value) {
                        for (String tag : candidateTags) {
                            if (value.f1.contains(tag)) {
                                return true;
                            }
                        }
                        return false;
                    }
                })
                .name("filter_candidate_fea_tag");
        }

        // Side output: frequencies before the frequency / top-N filters are applied.
        DataSet<Row> frequencyRows = featureCounts.map(new MapFunction<Tuple2<Integer, String>, Row>() {
            @Override
            public Row map(Tuple2<Integer, String> value) {
                return Row.of(value.f1, value.f0);
            }
        });
        setSideOutputTables(new Table[] {
            DataSetConversionUtil.toTable(getMLEnvironmentId(), frequencyRows,
                new String[] {"feature", "feature_count"},
                new TypeInformation<?>[] {Types.STRING, Types.INT})
        });

        if (minFrequency > 0) {
            featureCounts = featureCounts
                .filter(new FilterFunction<Tuple2<Integer, String>>() {
                    @Override
                    public boolean filter(Tuple2<Integer, String> value) {
                        return value.f0 >= minFrequency;
                    }
                })
                .name("filter_less_frequency_fea");
        } else if (minPercent > Criteria.INVALID_GAIN) {
            // NOTE(review): Criteria.INVALID_GAIN comes from the decompiler substituting a numeric literal;
            // presumably the intent is "minPercent > 0" — confirm against Criteria.INVALID_GAIN's value.
            featureCounts = featureCounts
                .filter(new RichFilterFunction<Tuple2<Integer, String>>() {
                    private int threshold;

                    @Override
                    public void open(Configuration parameters) throws Exception {
                        long rows = getRuntimeContext().<Long>getBroadcastVariable(SAMPLE_COUNT_BROADCAST).get(0);
                        threshold = (int) Math.floor(rows * minPercent);
                    }

                    @Override
                    public boolean filter(Tuple2<Integer, String> value) {
                        return value.f0 >= threshold;
                    }
                })
                .withBroadcastSet(sampleCount, SAMPLE_COUNT_BROADCAST)
                .name("filter_less_frequency_fea");
        }

        if (topN > 0) {
            featureCounts = featureCounts
                .rebalance()
                .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, String>, Tuple2<Integer, String>>() {
                    private static final long serialVersionUID = 2590378621506355139L;

                    @Override
                    public void mapPartition(Iterable<Tuple2<Integer, String>> values,
                                             Collector<Tuple2<Integer, String>> out) throws Exception {
                        // Heuristic local pre-selection: keep about 3x this partition's share of the global
                        // top-N. Computed in long to avoid overflow, and clamped to at least 1. A zero capacity
                        // would make updateQueue compare against a null minimum and fail.
                        long share = (topN * 3L) / getRuntimeContext().getNumberOfParallelSubtasks();
                        int capacity = (int) Math.max(1L, Math.min(share, (long) Integer.MAX_VALUE));
                        emitAll(selectTopByCount(values, capacity), out);
                    }
                })
                .name("get_topn_fea_map")
                .reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, String>, Tuple2<Integer, String>>() {
                    @Override
                    public void reduce(Iterable<Tuple2<Integer, String>> values,
                                       Collector<Tuple2<Integer, String>> out) throws Exception {
                        emitAll(selectTopByCount(values, topN), out);
                    }
                })
                .name("get_topn_fea_reduce");
        }

        DataSet<Tuple3<Integer, String, Long>> indexed = HugeStringIndexerUtil.indexSortedByAlphabet(
            featureCounts.map(new MapFunction<Tuple2<Integer, String>, Tuple2<Integer, String>>() {
                @Override
                public Tuple2<Integer, String> map(Tuple2<Integer, String> value) {
                    // The first field is 0 for every feature. Presumably this is the column id that
                    // indexSortedByAlphabet groups by, i.e. all features are indexed as one column. TODO confirm.
                    return Tuple2.of(0, value.f1);
                }
            }), 0L, true);

        // Params the predictor needs to parse features the same way; serialized into the model's first row.
        final Params modelMeta = new Params();
        modelMeta.set(SparseFeatureIndexerTrainParams.HAS_VALUE, getHasValue());
        modelMeta.set(SparseFeatureIndexerTrainParams.SPARSE_FEATURE_DELIMITER, getSpareFeatureDelimiter());
        modelMeta.set(SparseFeatureIndexerTrainParams.KV_VAL_DELIMITER, getKvValDelimiter());

        StringIndexerModelDataConverter schemaSource = new StringIndexerModelDataConverter();
        DataSet<Row> model = indexed
            .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
                private static final long serialVersionUID = 8019085781267407813L;

                @Override
                public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
                    throws Exception {
                    // Emit the meta row exactly once, from subtask 0.
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        out.collect(Row.of(modelMeta.toJson(), null));
                    }
                    new StringIndexerModelDataConverter().save2(values, out);
                }
            })
            .name("build_model")
            .returns(new RowTypeInfo(schemaSource.getModelSchema().getFieldTypes()));
        setOutput(model, schemaSource.getModelSchema());
        return this;
    }

    /** Builds a min-heap of (count, feature) tuples ordered by count. */
    private static PriorityQueue<Tuple2<Integer, String>> newCountMinHeap() {
        return new PriorityQueue<>(Comparator.comparingInt((Tuple2<Integer, String> t) -> t.f0));
    }

    /** Returns a heap holding the (at most {@code capacity}) highest-count tuples of {@code values}. */
    private static PriorityQueue<Tuple2<Integer, String>> selectTopByCount(
        Iterable<Tuple2<Integer, String>> values, int capacity) {
        PriorityQueue<Tuple2<Integer, String>> heap = newCountMinHeap();
        Tuple2<Integer, String> currentMin = null;
        for (Tuple2<Integer, String> value : values) {
            currentMin = updateQueue(heap, capacity, value, currentMin);
        }
        return heap;
    }

    /** Forwards every tuple of {@code items} to {@code out}. */
    private static void emitAll(Iterable<Tuple2<Integer, String>> items, Collector<Tuple2<Integer, String>> out) {
        for (Tuple2<Integer, String> item : items) {
            out.collect(item);
        }
    }

    /**
     * Offers {@code candidate} to a bounded min-heap that keeps the highest-count tuples.
     *
     * <p>Internal helper of the top-N selection. The heap must have been created with a comparator, and
     * {@code currentMin} must be the heap's current head ({@code null} while the heap is empty).
     *
     * @param queue      heap to update; candidates are copied in, never stored by reference
     * @param capacity   maximum heap size; a non-positive capacity leaves the heap untouched
     * @param candidate  tuple to offer; ignored when {@code null}
     * @param currentMin the heap's head before this call
     * @return the heap's head after this call
     */
    public static Tuple2<Integer, String> updateQueue(PriorityQueue<Tuple2<Integer, String>> queue, int capacity,
                                                      Tuple2<Integer, String> candidate,
                                                      Tuple2<Integer, String> currentMin) {
        if (candidate == null || capacity <= 0) {
            return currentMin;
        }
        if (queue.size() < capacity) {
            // Copy: Flink may reuse the objects handed out by an input iterator.
            queue.add(Tuple2.of(candidate.f0, candidate.f1));
            return queue.peek();
        }
        if (queue.comparator().compare(currentMin, candidate) < 0) {
            // Evict the smallest and recycle its tuple object for the new entry.
            Tuple2<Integer, String> evicted = queue.poll();
            evicted.f0 = candidate.f0;
            evicted.f1 = candidate.f1;
            queue.add(evicted);
            return queue.peek();
        }
        return currentMin;
    }
}
