package com.alibaba.alink.operator.common.nlp;

import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction;
import org.apache.flink.shaded.guava18.com.google.common.hash.Hashing;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/nlp/DocHashCountVectorizerModelMapper.class */
/**
 * Maps a space-separated document string to a sparse feature vector using the hashing trick and a
 * {@link DocHashCountVectorizerModelData} model.
 *
 * <p>Each token is hashed with 32-bit MurmurHash3 (seed 0) into one of {@code numFeatures}
 * buckets. Only buckets that appear in the model's IDF map are counted. Buckets whose raw count is
 * below the minimum term frequency are dropped. Each surviving count is converted into a feature
 * value by the configured {@link FeatureType}'s {@code featureValueFunc}.
 */
public class DocHashCountVectorizerModelMapper extends SISOModelMapper {
    private static final long serialVersionUID = -4218147866842462735L;
    private DocHashCountVectorizerModelData model;
    private FeatureType featureType;
    /**
     * Token-to-bucket hash. NOTE(review): presumably must match the hashing used when the model's
     * IDF map was built, since bucket indices are looked up in that map — confirm against the
     * trainer before changing.
     */
    private static final HashFunction HASH = Hashing.murmur3_32(0);

    public DocHashCountVectorizerModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
    }

    /**
     * Deserializes the model rows and resolves the configured feature type.
     *
     * @param list serialized model rows
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        this.model = new DocHashCountVectorizerModelDataConverter().load(list);
        // Locale.ROOT: default-locale upper-casing (e.g. Turkish dotted/dotless i) would produce a
        // name that FeatureType.valueOf does not recognize.
        this.featureType = FeatureType.valueOf(this.model.featureType.toUpperCase(Locale.ROOT));
    }

    /** The prediction column holds a sparse vector. */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected TypeInformation<?> initPredResultColType() {
        return AlinkTypes.SPARSE_VECTOR;
    }

    /**
     * Vectorizes one document.
     *
     * @param obj the document; expected to be a {@link String} (cast directly), may be null
     * @return a {@link SparseVector} of size {@code numFeatures}, or null when {@code obj} is null
     */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected Object predictResult(Object obj) {
        if (null == obj) {
            return null;
        }
        // String.split drops trailing empty strings; consecutive spaces still yield "" tokens,
        // which are hashed like any other token. An all-space input yields zero tokens.
        String[] tokens = ((String) obj).split(" ");

        // minTF >= 1 is an absolute count; a value below 1 is a fraction of the document length.
        double minTermCount = model.minTF >= 1.0 ? model.minTF : model.minTF * tokens.length;
        // Infinite for an empty token array, but then there are no buckets to use it with.
        double inverseDocLength = 1.0 / tokens.length;

        Map<Integer, Integer> bucketCounts = countKnownBuckets(tokens);
        int[] indices = new int[bucketCounts.size()];
        double[] values = new double[indices.length];
        int nnz = 0;
        for (Map.Entry<Integer, Integer> entry : bucketCounts.entrySet()) {
            double count = entry.getValue();
            if (count >= minTermCount) {
                indices[nnz] = entry.getKey();
                values[nnz] = (Double) featureType.featureValueFunc.apply(
                    model.idfMap.get(entry.getKey()), count, inverseDocLength);
                nnz++;
            }
        }
        return new SparseVector(model.numFeatures, Arrays.copyOf(indices, nnz), Arrays.copyOf(values, nnz));
    }

    /** Counts tokens per hash bucket, keeping only buckets present in the model's IDF map. */
    private Map<Integer, Integer> countKnownBuckets(String[] tokens) {
        // Initial capacity 0 kept from the original so map iteration (and index) order is unchanged.
        Map<Integer, Integer> counts = new HashMap<>(0);
        for (String token : tokens) {
            int bucket = Math.floorMod(Math.abs(HASH.hashUnencodedChars(token).asInt()), model.numFeatures);
            if (model.idfMap.containsKey(bucket)) {
                counts.merge(bucket, 1, Integer::sum);
            }
        }
        return counts;
    }
}
