package com.alibaba.alink.operator.batch.nlp;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.nlp.DocHashCountVectorizerModelData;
import com.alibaba.alink.operator.common.nlp.DocHashCountVectorizerModelDataConverter;
import com.alibaba.alink.operator.common.nlp.FeatureType;
import com.alibaba.alink.params.nlp.DocHashCountVectorizerTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction;
import org.apache.flink.shaded.guava18.com.google.common.hash.Hashing;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch trainer that produces the model rows for the DocHashCountVectorizer.
 *
 * <p>Pipeline built by {@link #linkFrom(BatchOperator[])}:
 * <ol>
 *   <li>{@link HashingTF} hashes every space-separated token of the selected column into one of
 *       {@code numFeatures} buckets and counts, per partition, the rows seen and the hits per bucket;</li>
 *   <li>a reduce step sums the row totals and merges the per-partition bucket counts;</li>
 *   <li>{@link BuildModel} drops buckets below the minimum frequency, converts the remaining counts
 *       to IDF weights and serializes the model.</li>
 * </ol>
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.STRING_TYPES})
@NameCn("文本哈希特征生成训练")
@NameEn("DocHash Count Vectorizer Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.nlp.DocHashCountVectorizer")
public class DocHashCountVectorizerTrainBatchOp extends BatchOperator<DocHashCountVectorizerTrainBatchOp> implements DocHashCountVectorizerTrainParams<DocHashCountVectorizerTrainBatchOp> {
    private static final long serialVersionUID = 6469196128919853279L;

    /**
     * Turns the globally merged {@code (rowCount, bucket -> count)} tuple into serialized model rows.
     */
    public static class BuildModel implements FlatMapFunction<Tuple2<Long, HashMap<Integer, Double>>, Row> {
        private static final long serialVersionUID = 5598329225247209656L;
        private final double minDocFrequency;
        private final int numFeatures;
        private final String featureType;
        private final double minTF;

        /**
         * @param params training parameters; MIN_DF, NUM_FEATURES, FEATURE_TYPE and MIN_TF are read.
         */
        public BuildModel(Params params) {
            this.minDocFrequency = params.get(DocHashCountVectorizerTrainParams.MIN_DF);
            this.numFeatures = params.get(DocHashCountVectorizerTrainParams.NUM_FEATURES);
            // The model stores the feature type by its enum constant name.
            this.featureType = params.get(DocHashCountVectorizerTrainParams.FEATURE_TYPE).name();
            this.minTF = params.get(DocHashCountVectorizerTrainParams.MIN_TF);
        }

        /**
         * Filters and re-weights the bucket counts in place, then emits the model through the converter.
         *
         * @param value {@code f0} is the total number of rows, {@code f1} maps bucket index to its count;
         *              {@code f1} is modified in place and becomes the model's IDF map
         * @param out   receives the rows written by {@link DocHashCountVectorizerModelDataConverter}
         */
        @Override
        public void flatMap(Tuple2<Long, HashMap<Integer, Double>> value, Collector<Row> out) throws Exception {
            final long docCount = value.f0;
            // A configured value >= 1 is an absolute threshold; a smaller one is a fraction of the
            // row total. Kept in a local so the configured value is never overwritten.
            final double threshold = minDocFrequency >= 1.0d ? minDocFrequency : minDocFrequency * docCount;

            final HashMap<Integer, Double> idfMap = value.f1;
            Iterator<Map.Entry<Integer, Double>> entries = idfMap.entrySet().iterator();
            while (entries.hasNext()) {
                Map.Entry<Integer, Double> entry = entries.next();
                double frequency = entry.getValue();
                if (frequency >= threshold) {
                    // Smoothed inverse document frequency: log((N + 1) / (count + 1)).
                    entry.setValue(Math.log((docCount + 1.0d) / (frequency + 1.0d)));
                } else {
                    entries.remove();
                }
            }

            DocHashCountVectorizerModelData model = new DocHashCountVectorizerModelData();
            model.numFeatures = numFeatures;
            model.minTF = minTF;
            model.featureType = featureType;
            model.idfMap = idfMap;
            new DocHashCountVectorizerModelDataConverter().save(model, out);
        }
    }

    /**
     * Per-partition counter: hashes each space-separated token of the selected field into a bucket
     * and emits one {@code (rowCount, bucket -> count)} tuple for the partition.
     */
    public static class HashingTF implements MapPartitionFunction<Row, Tuple2<Long, HashMap<Integer, Double>>> {
        private static final long serialVersionUID = -7172651314711810032L;
        private static final HashFunction HASH = Hashing.murmur3_32(0);
        private final int index;
        private final int numFeatures;

        /**
         * @param index       position of the text field inside each input row
         * @param numFeatures number of hash buckets
         */
        public HashingTF(int index, int numFeatures) {
            this.index = index;
            this.numFeatures = numFeatures;
        }

        /**
         * Counts the rows of the partition and, for every token, increments the bucket it hashes to.
         *
         * <p>NOTE(review): a bucket is incremented once per token occurrence, not once per row, although
         * the result is later compared against the minimum document frequency — confirm this is intended.
         */
        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Tuple2<Long, HashMap<Integer, Double>>> out) throws Exception {
            HashMap<Integer, Double> bucketCounts = new HashMap<>(numFeatures);
            long docCount = 0;
            for (Row row : rows) {
                docCount++;
                String content = (String) row.getField(index);
                for (String word : content.split(" ")) {
                    // Math.abs(Integer.MIN_VALUE) is still negative; floorMod keeps the bucket non-negative.
                    int hash = Math.abs(HASH.hashUnencodedChars(word).asInt());
                    bucketCounts.merge(Math.floorMod(hash, numFeatures), 1.0d, Double::sum);
                }
            }
            out.collect(Tuple2.of(docCount, bucketCounts));
        }
    }

    /** Creates the operator with an empty parameter set. */
    public DocHashCountVectorizerTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters passed on to {@link BatchOperator}
     */
    public DocHashCountVectorizerTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Wires the training pipeline onto the first input operator and sets the model as this operator's output.
     *
     * @param inputs upstream operators; the first one supplies the training data
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public DocHashCountVectorizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        int selectedColIndex = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), getSelectedCol());

        DataSet<Row> model = in.getDataSet()
            .rebalance()
            .mapPartition(new HashingTF(selectedColIndex, getNumFeatures()))
            .reduce(new ReduceFunction<Tuple2<Long, HashMap<Integer, Double>>>() {
                private static final long serialVersionUID = 7849950640425941402L;

                @Override
                public Tuple2<Long, HashMap<Integer, Double>> reduce(
                        Tuple2<Long, HashMap<Integer, Double>> left,
                        Tuple2<Long, HashMap<Integer, Double>> right) {
                    // Fold the right partition's bucket counts into the left one; without this only
                    // a single partition's counts would reach BuildModel.
                    right.f1.forEach((bucket, count) -> left.f1.merge(bucket, count, Double::sum));
                    left.f0 = left.f0 + right.f0;
                    return left;
                }
            })
            .flatMap(new BuildModel(getParams()))
            // The reduce yields a single tuple, so one task is enough to build the model.
            .setParallelism(1);

        setOutput(model, new DocHashCountVectorizerModelDataConverter().getModelSchema());
        return this;
    }
}
