package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.classification.SoftmaxPredictParams;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/SoftmaxModelMapper.class */
/**
 * Model mapper that predicts with a softmax (multinomial logistic) linear model.
 *
 * <p>Features are taken either from a single vector column or from the individual columns
 * listed in the model's {@code featureNames}. The vector column name comes from the predict
 * params (VECTOR_COL) if set; otherwise, when the model has no {@code featureNames}, it comes
 * from the model's own {@code vectorColName}. One coefficient vector is used per label except
 * the last. The last label acts as the reference class whose logit is fixed at 0.
 */
public class SoftmaxModelMapper extends RichModelMapper {
    private static final long serialVersionUID = -4309479266141950255L;

    /** Data-schema index of the feature-vector column, or -1 when features come from separate columns. */
    protected int vectorColIndex;
    /** The softmax model, set by {@link #loadModel(List)}. */
    private LinearModelData model;
    /** Name of the feature-vector column (from the predict params, or from the model as a fallback). */
    private String vectorColName;
    /** Data-schema indices of the model's feature columns; only set when no vector column is used. */
    private int[] featureIdx;
    /** Per-thread scratch vector that column features are copied into; created in {@link #loadModel(List)}. */
    private transient ThreadLocal<DenseVector> threadLocalVec;

    /**
     * Creates the mapper and, if the predict params name a vector column, resolves its index.
     *
     * @param modelSchema first schema, passed straight to the superclass (presumably the model schema)
     * @param dataSchema  schema searched for the vector column named by the predict params
     * @param params      predict params; may be {@code null}, in which case nothing is resolved here
     */
    public SoftmaxModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.vectorColIndex = -1;
        if (params == null) {
            return;
        }
        this.vectorColName = (String) params.get(SoftmaxPredictParams.VECTOR_COL);
        boolean hasVectorCol = this.vectorColName != null && !this.vectorColName.isEmpty();
        if (hasVectorCol) {
            this.vectorColIndex = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), this.vectorColName);
        }
    }

    /**
     * Loads the model and resolves where the features will be read from.
     *
     * <p>If the vector column was already resolved in the constructor, nothing else is done.
     * Otherwise, a model without {@code featureNames} supplies the vector column name. A model
     * with {@code featureNames} has each name resolved against the data schema, and a per-thread
     * dense buffer is prepared. That buffer has one extra slot when the model has an intercept.
     *
     * @param modelRows serialized model rows, decoded by {@link LinearModelDataConverter}
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> modelRows) {
        this.model = new LinearModelDataConverter().load(modelRows);
        if (this.vectorColIndex != -1) {
            return;
        }

        String[] fieldNames = getDataSchema().getFieldNames();
        if (this.model.featureNames == null) {
            this.vectorColName = this.model.vectorColName;
            this.vectorColIndex = TableUtil.findColIndexWithAssert(fieldNames, this.model.vectorColName);
            return;
        }

        int featureCount = this.model.featureNames.length;
        this.featureIdx = new int[featureCount];
        for (int k = 0; k < featureCount; k++) {
            this.featureIdx[k] = TableUtil.findColIndexWithAssert(fieldNames, this.model.featureNames[k]);
        }
        final int bufferSize = featureCount + (this.model.hasInterceptItem ? 1 : 0);
        this.threadLocalVec = ThreadLocal.withInitial(() -> new DenseVector(bufferSize));
    }

    /**
     * Returns only the predicted label.
     *
     * @param slicedSelectedSample the selected input columns of one row
     * @return the predicted label
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictResultDetail(slicedSelectedSample).f0;
    }

    /**
     * Predicts one row and reports the per-label probabilities.
     *
     * @param selection the selected input columns of one row
     * @return the predicted label, and a JSON object that maps each label's {@code toString()} to
     *     its probability's {@code toString()} text
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample selection) {
        Vector features;
        if (this.vectorColIndex != -1) {
            features = FeatureLabelUtil.getVectorFeature(
                selection.get(this.vectorColIndex), this.model.hasInterceptItem, Integer.valueOf(this.model.vectorSize));
        } else {
            // Reuse this thread's buffer rather than allocating a vector per row.
            DenseVector buffer = this.threadLocalVec.get();
            selection.fillDenseVector(buffer, this.model.hasInterceptItem, this.featureIdx);
            features = buffer;
        }

        Tuple2<Object, Double[]> prediction = predictSoftmaxWithProb(features);
        Double[] probs = prediction.f1;
        int numLabels = this.model.labelValues.length;
        // Initial capacity 0 kept as before so the map's iteration (and thus JSON key) order is unchanged.
        HashMap<String, String> detail = new HashMap<>(0);
        for (int k = 0; k < numLabels; k++) {
            detail.put(this.model.labelValues[k].toString(), probs[k].toString());
        }
        return new Tuple2<>(prediction.f0, JsonConverter.gson.toJson(detail));
    }

    /**
     * Computes the per-label probabilities of {@code vector} and picks the most probable label.
     *
     * <p>Label {@code k < n - 1} has logit {@code dot(vector, coefVectors[k])}. The last label is
     * the reference class with logit 0. Probabilities are computed as
     * {@code exp(logit_k - max) / sum_j exp(logit_j - max)}. This equals the plain softmax but
     * cannot overflow. The naive {@code exp(logit_k) / (1 + sum_j exp(logit_j))} turns into
     * {@code Infinity / Infinity = NaN} once a logit exceeds about 709, which mislabels the row.
     *
     * @param vector the feature vector assembled by {@link #predictResultDetail}
     * @return the predicted label (the first label wins ties), and the probabilities in
     *     {@code labelValues} order
     * @throws IllegalStateException    if the model has no labels
     * @throws IllegalArgumentException if a logit is NaN, e.g. because an input feature is NaN
     */
    private Tuple2<Object, Double[]> predictSoftmaxWithProb(Vector vector) {
        int numLabels = this.model.labelValues.length;
        if (numLabels == 0) {
            throw new IllegalStateException("Softmax model has no label values.");
        }
        DenseVector[] coefVectors = this.model.coefVectors;

        // logits[numLabels - 1] stays 0.0: the last label is the reference class.
        double[] logits = new double[numLabels];
        for (int k = 0; k < numLabels - 1; k++) {
            double logit = FeatureLabelUtil.dot(vector, coefVectors[k]);
            if (Double.isNaN(logit)) {
                throw new IllegalArgumentException(
                    "Softmax logit for label index " + k + " is NaN; check the input features for NaN values.");
            }
            logits[k] = logit;
        }

        // Argmax over the logits; the strict '>' keeps the first label on ties.
        int best = 0;
        for (int k = 1; k < numLabels; k++) {
            if (logits[k] > logits[best]) {
                best = k;
            }
        }
        double maxLogit = logits[best];

        double[] weights = new double[numLabels];
        double total = 0.0;
        for (int k = 0; k < numLabels; k++) {
            // The maximum maps to exp(0) = 1 explicitly, so an infinite maximum cannot yield exp(Inf - Inf) = NaN.
            weights[k] = logits[k] == maxLogit ? 1.0 : Math.exp(logits[k] - maxLogit);
            total += weights[k];
        }

        // total >= 1 because the maximum contributes exactly 1, so the division is always safe.
        Double[] probs = new Double[numLabels];
        for (int k = 0; k < numLabels; k++) {
            probs[k] = weights[k] / total;
        }
        return new Tuple2<>(this.model.labelValues[best], probs);
    }
}
