package com.alibaba.alink.operator.batch.nlp;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelData;
import com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelDataConverter;
import com.alibaba.alink.operator.common.nlp.DocWordSplitCount;
import com.alibaba.alink.operator.common.nlp.FeatureType;
import com.alibaba.alink.operator.common.nlp.WordCountUtil;
import com.alibaba.alink.params.nlp.DocCountVectorizerTrainParams;
import com.alibaba.alink.params.nlp.DocHashCountVectorizerTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.STRING_TYPES})
@NameCn("文本特征生成训练")
@NameEn("Doc Count Vectorizer Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.nlp.DocCountVectorizer")
/* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/DocCountVectorizerTrainBatchOp.class */
public final class DocCountVectorizerTrainBatchOp extends BatchOperator<DocCountVectorizerTrainBatchOp> implements DocCountVectorizerTrainParams<DocCountVectorizerTrainBatchOp> {
    private static final String WORD_COL_NAME = "word";
    private static final String DOC_WORD_COUNT_COL_NAME = "doc_word_cnt";
    private static final String DOC_COUNT_COL_NAME = "doc_cnt";
    private static final long serialVersionUID = -5063129126354049743L;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/DocCountVectorizerTrainBatchOp$BuildDocCountModel.class */
    public static class BuildDocCountModel implements MapPartitionFunction<Tuple2<String, Double>, DocCountVectorizerModelData> {
        private static final long serialVersionUID = 4285272379018931290L;
        private String featureType;
        private double minTF;

        public BuildDocCountModel(Params params) {
            this.featureType = ((FeatureType) params.get(DocHashCountVectorizerTrainParams.FEATURE_TYPE)).name();
            this.minTF = ((Double) params.get(DocHashCountVectorizerTrainParams.MIN_TF)).doubleValue();
        }

        public void mapPartition(Iterable<Tuple2<String, Double>> iterable, Collector<DocCountVectorizerModelData> collector) throws Exception {
            ArrayList arrayList = new ArrayList();
            Tuple3 of = Tuple3.of((Object) null, (Object) null, (Object) null);
            int i = 0;
            for (Tuple2<String, Double> tuple2 : iterable) {
                of.f0 = tuple2.f0;
                of.f1 = tuple2.f1;
                int i2 = i;
                i++;
                of.f2 = Integer.valueOf(i2);
                arrayList.add(JsonConverter.toJson(of));
            }
            DocCountVectorizerModelData docCountVectorizerModelData = new DocCountVectorizerModelData();
            docCountVectorizerModelData.featureType = this.featureType;
            docCountVectorizerModelData.minTF = this.minTF;
            docCountVectorizerModelData.list = arrayList;
            collector.collect(docCountVectorizerModelData);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/DocCountVectorizerTrainBatchOp$CalcIdf.class */
    public static class CalcIdf extends RichGroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = -6966374477296290847L;
        private long docCnt;
        private double maxDF;
        private double minDF;

        public CalcIdf(double d, double d2) {
            this.maxDF = d;
            this.minDF = d2;
        }

        public void open(Configuration configuration) throws Exception {
            this.docCnt = ((Number) ((Row) getRuntimeContext().getBroadcastVariable("docCnt").get(0)).getField(0)).longValue();
            this.maxDF = this.maxDF >= 1.0d ? this.maxDF : this.maxDF * this.docCnt;
            this.minDF = this.minDF >= 1.0d ? this.minDF : this.minDF * this.docCnt;
            if (this.maxDF < this.minDF) {
                throw new AkIllegalOperatorParameterException("MaxDF must be larger than MinDF!");
            }
        }

        public void reduce(Iterable<Row> iterable, Collector<Row> collector) {
            double d = 0.0d;
            double d2 = 0.0d;
            Object obj = null;
            for (Row row : iterable) {
                if (null == obj) {
                    obj = row.getField(0);
                }
                d += 1.0d;
                d2 += ((Number) row.getField(1)).doubleValue();
            }
            if (d < this.minDF || d > this.maxDF) {
                return;
            }
            collector.collect(Row.of(new Object[]{obj, Double.valueOf(-d2), Double.valueOf(Math.log((1.0d + this.docCnt) / (1.0d + d)))}));
        }
    }

    /**
     * Creates the operator without explicit parameters by delegating to
     * {@link #DocCountVectorizerTrainBatchOp(Params)} with {@code null}.
     */
    public DocCountVectorizerTrainBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters, which are passed unchanged to the
     * {@link BatchOperator} constructor.
     *
     * @param params operator parameters; the no-arg constructor passes {@code null} here
     */
    public DocCountVectorizerTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the model from the first input operator and installs it as this operator's output.
     *
     * <p>The model data produced by {@link #generateDocCountModel(Params, BatchOperator)} is
     * converted to rows with {@link DocCountVectorizerModelDataConverter}, whose model schema is
     * used as the output schema.
     *
     * @param batchOperatorArr input operators; the one returned by {@code checkAndGetFirst} is used
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public DocCountVectorizerTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        DataSet<Row> modelRows = generateDocCountModel(getParams(), checkAndGetFirst(batchOperatorArr))
            .mapPartition(new MapPartitionFunction<DocCountVectorizerModelData, Row>() {
                private static final long serialVersionUID = -246525084223240789L;

                @Override
                public void mapPartition(Iterable<DocCountVectorizerModelData> iterable, Collector<Row> collector) {
                    // The model is produced upstream with parallelism 1, so a parallel instance of
                    // this function may be handed an empty partition. Skip it instead of failing
                    // with NoSuchElementException; only the first record is ever saved.
                    for (DocCountVectorizerModelData modelData : iterable) {
                        new DocCountVectorizerModelDataConverter().save(modelData, collector);
                        break;
                    }
                }
            });
        setOutput(modelRows, new DocCountVectorizerModelDataConverter().getModelSchema());
        return this;
    }

    /**
     * Builds the data flow that produces the vocabulary model.
     *
     * <ol>
     *   <li>The selected column is expanded by {@link DocWordSplitCount} (constructed with
     *       {@code " "}) into rows with the columns {@code word} and {@code doc_word_cnt}.</li>
     *   <li>Rows are grouped by field 0 and reduced by {@link CalcIdf}, which receives the result
     *       of {@code COUNT(1)} over the input as the broadcast set {@code "docCnt"}.</li>
     *   <li>The reduced rows are ordered on field 1 via {@code SortUtils.pSort} followed by
     *       {@code WordCountUtil.localSort}, which attaches a {@code Long} position to each row.</li>
     *   <li>Only rows whose position is below {@code VOCAB_SIZE} are kept, as (field 0 as string,
     *       field 2 as double) tuples.</li>
     *   <li>All kept tuples are routed to partition 0, sorted by word in descending order and
     *       turned into one model record by {@link BuildDocCountModel} with parallelism 1.</li>
     * </ol>
     *
     * @param params        training parameters ({@code SELECTED_COL}, {@code MAX_DF}, {@code MIN_DF},
     *                      {@code VOCAB_SIZE}, plus those read by {@link BuildDocCountModel})
     * @param batchOperator operator providing the documents
     * @return data set holding the model data
     */
    public static DataSet<DocCountVectorizerModelData> generateDocCountModel(Params params, BatchOperator batchOperator) {
        // Total number of input rows, broadcast to CalcIdf under the name "docCnt".
        DataSet docCount = batchOperator.select("COUNT(1) AS " + DOC_COUNT_COL_NAME).getDataSet();

        CalcIdf calcIdf = new CalcIdf(
            ((Double) params.get(MAX_DF)).doubleValue(),
            ((Double) params.get(MIN_DF)).doubleValue());

        DataSet<Row> wordStats = batchOperator
            .udtf(
                (String) params.get(SELECTED_COL),
                new String[]{WORD_COL_NAME, DOC_WORD_COUNT_COL_NAME},
                new DocWordSplitCount(" "),
                new String[0])
            .select(new String[]{WORD_COL_NAME, DOC_WORD_COUNT_COL_NAME})
            .getDataSet()
            .groupBy(new int[]{0})
            .reduceGroup(calcIdf)
            .withBroadcastSet(docCount, "docCnt");

        Tuple2<DataSet<Tuple2<Integer, Row>>, DataSet<Tuple2<Integer, Long>>> partitioned = SortUtils.pSort(wordStats, 1);
        // Raw casts kept from the original call; localSort's declared parameter types are not visible here.
        DataSet<Tuple2<Long, Row>> ranked = WordCountUtil.localSort((DataSet) partitioned.f0, (DataSet) partitioned.f1, 1);

        final int vocabSize = ((Integer) params.get(VOCAB_SIZE)).intValue();

        DataSet<Tuple2<String, Double>> vocabulary = ranked.flatMap(new FlatMapFunction<Tuple2<Long, Row>, Tuple2<String, Double>>() {
            private static final long serialVersionUID = -1668412648425550909L;

            @Override
            public void flatMap(Tuple2<Long, Row> rankedRow, Collector<Tuple2<String, Double>> collector) throws Exception {
                if (rankedRow.f0.longValue() >= vocabSize) {
                    return;
                }
                String word = rankedRow.f1.getField(0).toString();
                double value = ((Number) rankedRow.f1.getField(2)).doubleValue();
                collector.collect(Tuple2.of(word, Double.valueOf(value)));
            }
        });

        return vocabulary
            .partitionCustom(new Partitioner<String>() {
                private static final long serialVersionUID = 5129015018479212319L;

                // Every key maps to partition 0 so that one task sees the whole vocabulary.
                @Override
                public int partition(String str, int i) {
                    return 0;
                }
            }, 0)
            .sortPartition(0, Order.DESCENDING)
            .mapPartition(new BuildDocCountModel(params))
            .setParallelism(1);
    }
}
