package com.alibaba.alink.operator.common.nlp;

import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.params.nlp.DocCountVectorizerTrainParams;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/nlp/DocCountVectorizerModelMapper.class */
/**
 * Maps a single space-separated text column to a {@link SparseVector} using a trained
 * DocCountVectorizer dictionary.
 *
 * <p>Each serialized dictionary entry is a JSON-encoded {@code Tuple3<String, Double, Integer>}
 * read as (word, weight, vector index). Only words present in the dictionary contribute to the
 * output vector, whose size is the number of dictionary entries.
 */
public class DocCountVectorizerModelMapper extends SISOModelMapper {
    /** Generic type of one serialized dictionary entry: (word, weight, vector index). */
    private static final Type DATA_TUPLE3_TYPE = new TypeReference<Tuple3<String, Double, Integer>>() {
    }.getType();
    private static final long serialVersionUID = 7431062592310976413L;

    /** Minimum term frequency: an absolute count when >= 1, otherwise a fraction of the token count. */
    private double minTF;
    private FeatureType featureType;
    /** word -> (vector index, weight). */
    private HashMap<String, Tuple2<Integer, Double>> wordIdWeight;
    /** Output vector size, equal to the number of dictionary entries. */
    private int featureNum;

    public DocCountVectorizerModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.featureType = (FeatureType) this.params.get(DocCountVectorizerTrainParams.FEATURE_TYPE);
    }

    @Override
    protected TypeInformation initPredResultColType() {
        return AlinkTypes.SPARSE_VECTOR;
    }

    /**
     * Loads the dictionary, min-TF threshold and feature type from the model rows.
     *
     * <p>The feature type stored in the model overrides the one taken from the params in the
     * constructor.
     *
     * @param list serialized model rows
     */
    @Override
    public void loadModel(List<Row> list) {
        DocCountVectorizerModelData load = new DocCountVectorizerModelDataConverter().load(list);
        this.featureNum = load.list.size();
        this.minTF = load.minTF;
        // Locale.ROOT: the default locale (e.g. Turkish) could turn 'i' into a dotted capital I,
        // which would make valueOf fail for names such as "binary" or "tf_idf".
        this.featureType = FeatureType.valueOf(load.featureType.toUpperCase(Locale.ROOT));
        // Presize by dictionary entries, not by the number of model rows.
        this.wordIdWeight = new HashMap<>(load.list.size());
        for (String json : load.list) {
            @SuppressWarnings("unchecked")
            Tuple3<String, Double, Integer> entry =
                (Tuple3<String, Double, Integer>) JsonConverter.fromJson(json, DATA_TUPLE3_TYPE);
            // Stored as (word, weight, index); indexed here as word -> (index, weight).
            this.wordIdWeight.put(entry.f0, Tuple2.of(entry.f2, entry.f1));
        }
    }

    /**
     * Converts one input value to a sparse vector.
     *
     * @param obj the input text, or {@code null}
     * @return the sparse vector, or {@code null} when the input is {@code null}
     */
    @Override
    protected Object predictResult(Object obj) {
        if (null == obj) {
            return null;
        }
        return predictSparseVector((String) obj, this.minTF, this.wordIdWeight, this.featureType, this.featureNum);
    }

    /**
     * Builds a sparse vector from one document.
     *
     * <p>The document is split on single spaces. Only tokens found in {@code wordIdWeight} are
     * counted. A word is emitted only if its count is at least the min-TF threshold. Each emitted
     * value is {@code featureType.featureValueFunc} applied to (word weight, word count,
     * 1 / total token count).
     *
     * @param content      space-separated document text; must not be {@code null}
     * @param minTF        absolute minimum count when {@code >= 1}, otherwise a fraction of the
     *                     document's token count
     * @param wordIdWeight dictionary mapping word -> (vector index, weight)
     * @param featureType  selects the function used to compute each feature value
     * @param featureNum   size of the resulting vector
     * @return a vector of size {@code featureNum} holding one entry per retained word
     */
    public static SparseVector predictSparseVector(String content, double minTF,
                                                   HashMap<String, Tuple2<Integer, Double>> wordIdWeight,
                                                   FeatureType featureType, int featureNum) {
        String[] tokens = content.split(" ");
        double minTermCount = minTF >= 1.0d ? minTF : minTF * tokens.length;
        double tokenRatio = 1.0d / tokens.length;

        Map<String, Integer> wordCount = new HashMap<>();
        for (String token : tokens) {
            if (wordIdWeight.containsKey(token)) {
                wordCount.merge(token, 1, Integer::sum);
            }
        }

        // Sized for the worst case (every counted word passes the threshold); trimmed below.
        int[] indices = new int[wordCount.size()];
        double[] values = new double[indices.length];
        int pos = 0;
        for (Map.Entry<String, Integer> entry : wordCount.entrySet()) {
            double count = entry.getValue();
            if (count >= minTermCount) {
                Tuple2<Integer, Double> idWeight = wordIdWeight.get(entry.getKey());
                indices[pos] = idWeight.f0;
                values[pos] = (Double) featureType.featureValueFunc.apply(idWeight.f1, count, tokenRatio);
                pos++;
            }
        }
        return new SparseVector(featureNum, Arrays.copyOf(indices, pos), Arrays.copyOf(values, pos));
    }
}
