package com.alibaba.alink.operator.local.nlp;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelData;
import com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelDataConverter;
import com.alibaba.alink.operator.common.nlp.FeatureType;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.params.nlp.DocCountVectorizerTrainParams;
import com.alibaba.alink.params.nlp.DocHashCountVectorizerTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Local (in-memory) trainer that builds a {@link DocCountVectorizerModelData} from a text column.
 *
 * <p>Each row's selected column is split on single spaces into words. For every word the trainer
 * accumulates the total number of occurrences and the number of rows it appears in, filters words
 * by the {@code MIN_DF}/{@code MAX_DF} bounds, keeps at most {@code VOCAB_SIZE} words with the
 * highest total occurrence count, and serializes the result as model rows.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.STRING_TYPES})
@NameCn("文本特征生成训练")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.nlp.DocCountVectorizer")
public final class DocCountVectorizerTrainLocalOp extends LocalOperator<DocCountVectorizerTrainLocalOp>
        implements DocCountVectorizerTrainParams<DocCountVectorizerTrainLocalOp> {
    private static final String WORD_COL_NAME = "word";
    private static final String DOC_WORD_COUNT_COL_NAME = "doc_word_cnt";
    private static final String DOC_COUNT_COL_NAME = "doc_cnt";

    /** Schema of the intermediate vocabulary table; column order is relied upon when reading rows back. */
    private static final String VOCAB_SCHEMA =
            WORD_COL_NAME + " string, " + DOC_WORD_COUNT_COL_NAME + " double, " + DOC_COUNT_COL_NAME + " double";

    /** Creates the operator by delegating to {@link #DocCountVectorizerTrainLocalOp(Params)} with {@code null}. */
    public DocCountVectorizerTrainLocalOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters, passed through to the superclass constructor
     */
    public DocCountVectorizerTrainLocalOp(Params params) {
        super(params);
    }

    /**
     * Trains the model from the first input operator and publishes it as this operator's output table.
     *
     * @param localOperatorArr input operators; the first one supplies the training rows
     * @return this operator, with its output table set to the serialized model
     */
    @Override
    public DocCountVectorizerTrainLocalOp linkFrom(LocalOperator<?>... localOperatorArr) {
        DocCountVectorizerModelData modelData =
                generateDocCountModel(getParams(), checkAndGetFirst(localOperatorArr));
        DocCountVectorizerModelDataConverter converter = new DocCountVectorizerModelDataConverter();
        RowCollector collector = new RowCollector();
        converter.save(modelData, collector);
        setOutputTable(new MTable(collector.getRows(), converter.getModelSchema()));
        return this;
    }

    /**
     * Builds the vocabulary model from the rows of {@code localOperator}.
     *
     * <p>Rows whose selected column is {@code null} or empty contribute no words, but they are still
     * counted in the total row count used by the logarithmic weight.
     *
     * @param params        parameters supplying the selected column, DF bounds, vocabulary size,
     *                      feature type and minimum TF
     * @param localOperator operator whose output table holds the documents
     * @return model data whose {@code list} holds one JSON-encoded (word, weight, index) tuple per
     *         kept word, where weight is {@code log((1 + rowCount) / (1 + docFrequency))}
     */
    public static DocCountVectorizerModelData generateDocCountModel(Params params, LocalOperator<?> localOperator) {
        String selectedCol = params.get(SELECTED_COL);
        int docColIndex = TableUtil.findColIndexWithAssert(localOperator.getSchema(), selectedCol);

        // Total occurrences per word, and number of rows containing each word.
        Map<String, Double> wordCounts = new HashMap<>();
        Map<String, Double> docCounts = new HashMap<>();
        HashSet<String> wordsInDoc = new HashSet<>();
        for (Row row : localOperator.getOutputTable().getRows()) {
            Object cell = row.getField(docColIndex);
            if (cell == null) {
                // Check before toString(): a null cell is skipped instead of raising an NPE.
                continue;
            }
            String content = cell.toString();
            if (content == null || content.isEmpty()) {
                continue;
            }
            wordsInDoc.clear();
            for (String word : content.split(" ")) {
                if (word.isEmpty()) {
                    // Consecutive spaces produce empty tokens; ignore them.
                    continue;
                }
                wordsInDoc.add(word);
                wordCounts.merge(word, 1.0d, Double::sum);
            }
            for (String word : wordsInDoc) {
                docCounts.merge(word, 1.0d, Double::sum);
            }
        }

        long rowCount = localOperator.getOutputTable().getNumRow();
        double maxDf = params.get(MAX_DF);
        double minDf = params.get(MIN_DF);
        // NOTE(review): both bounds are compared directly against the raw document count; if these
        // parameters are meant to also accept fractions of the corpus, that is not handled here — confirm.
        ArrayList<Row> vocabRows = new ArrayList<>();
        for (Map.Entry<String, Double> entry : wordCounts.entrySet()) {
            double docFrequency = docCounts.get(entry.getKey());
            if (docFrequency >= minDf && docFrequency <= maxDf) {
                // The third column is named "doc_cnt" but stores the logarithmic weight.
                double weight = Math.log((1.0d + rowCount) / (1.0d + docFrequency));
                vocabRows.add(Row.of(entry.getKey(), entry.getValue(), weight));
            }
        }

        // Keep the most frequent words: sort by total occurrences, descending, then truncate.
        MTable vocabTable = new MTable(vocabRows, VOCAB_SCHEMA);
        vocabTable.orderBy(new String[]{DOC_WORD_COUNT_COL_NAME}, new boolean[]{false});
        int vocabSize = params.get(VOCAB_SIZE);
        if (vocabTable.getNumRow() > vocabSize) {
            vocabTable = vocabTable.subTable(0, vocabSize);
        }

        // Encode each kept word as a JSON tuple (word, weight, index); index follows the sorted order.
        ArrayList<String> encodedWords = new ArrayList<>();
        int wordIndex = 0;
        for (Row row : vocabTable.getRows()) {
            Tuple3<String, Double, Integer> wordInfo =
                    Tuple3.of((String) row.getField(0), (Double) row.getField(2), wordIndex);
            wordIndex++;
            encodedWords.add(JsonConverter.toJson(wordInfo));
        }

        DocCountVectorizerModelData modelData = new DocCountVectorizerModelData();
        modelData.featureType = params.get(DocHashCountVectorizerTrainParams.FEATURE_TYPE).name();
        modelData.minTF = params.get(DocHashCountVectorizerTrainParams.MIN_TF);
        modelData.list = encodedWords;
        return modelData;
    }
}
