package com.alibaba.alink.operator.common.fm;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.linear.FeatureLabelUtil;
import com.alibaba.alink.operator.common.optim.FmOptimizer;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.recommendation.FmPredictParams;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/fm/FmModelMapper.class */
/**
 * Model mapper that scores rows with a factorization machine (FM) model.
 *
 * <p>Features are read either from a single vector column or from the model's list of feature
 * columns, which are packed into a per-thread reusable {@link DenseVector}. Regression models
 * return the raw FM output. Binary-classification models pass the output through a logistic
 * sigmoid and map it to one of the two label values stored in the model.
 */
public class FmModelMapper extends RichModelMapper {
    private static final long serialVersionUID = -6348182372481296494L;

    /** Data-schema index of the feature vector column, or -1 when features come from columns. */
    private int vectorColIndex;

    /** The loaded model; populated by {@link #loadModel(List)}. */
    private FmModelData model;

    /** Copy of {@code model.dim}, passed to {@link FmOptimizer#calcY}. */
    private int[] dim;

    /** Data-schema indices of the model's feature columns (only set on the column-based path). */
    private int[] featureIdx;

    /** Field type at index 2 of the model schema; used to cast the stored label values. */
    private final TypeInformation<?> labelType;

    /** Per-thread buffer for the column-based path; transient, created in {@link #loadModel}. */
    private transient ThreadLocal<DenseVector> threadLocalVec;

    /**
     * Creates the mapper.
     *
     * @param tableSchema  schema of the model table; its field type at index 2 is kept as the label
     *                     type (NOTE(review): presumably the label column — confirm against
     *                     FmModelDataConverter)
     * @param tableSchema2 schema of the data table to be scored
     * @param params       prediction parameters; if {@code VECTOR_COL} is set and non-empty, that
     *                     column is resolved in {@code tableSchema2} and used as the feature vector
     */
    public FmModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.vectorColIndex = -1;
        if (params != null) {
            String vectorColName = (String) params.get(FmPredictParams.VECTOR_COL);
            if (vectorColName != null && !vectorColName.isEmpty()) {
                this.vectorColIndex =
                    TableUtil.findColIndexWithAssert(tableSchema2.getFieldNames(), vectorColName);
            }
        }
        this.labelType = tableSchema.getFieldTypes()[2];
    }

    /**
     * Deserializes the model and resolves which data columns supply the features.
     *
     * <p>If no vector column was configured through the parameters, the model decides. When it
     * carries feature column names, their indices are resolved and a per-thread dense buffer of
     * that length is prepared. Otherwise the model's vector column name is resolved instead.
     *
     * @param list model rows
     */
    @Override
    public void loadModel(List<Row> list) {
        this.model = new FmModelDataConverter(
            FmModelDataConverter.extractLabelType(super.getModelSchema())).load(list);
        this.dim = this.model.dim;
        castLabelValues(this.model.labelValues);

        if (this.vectorColIndex != -1) {
            // A vector column was configured explicitly; nothing more to resolve.
            return;
        }
        TableSchema dataSchema = getDataSchema();
        if (this.model.featureColNames == null) {
            this.vectorColIndex =
                TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), this.model.vectorColName);
            return;
        }
        int length = this.model.featureColNames.length;
        String[] fieldNames = dataSchema.getFieldNames();
        this.featureIdx = new int[length];
        for (int i = 0; i < length; i++) {
            this.featureIdx[i] = TableUtil.findColIndexWithAssert(fieldNames, this.model.featureColNames[i]);
        }
        // One reusable buffer per thread, so concurrent predictions do not share a vector.
        this.threadLocalVec = ThreadLocal.withInitial(() -> new DenseVector(length));
    }

    /**
     * Converts the stored label values to Integer or Long when the label type is INT or LONG.
     * The values are parsed through their string form as a double and then truncated. The array
     * is modified in place; a null array and null entries are left untouched.
     */
    private void castLabelValues(Object[] labelValues) {
        boolean toInt = this.labelType.equals(Types.INT);
        boolean toLong = this.labelType.equals(Types.LONG);
        if (labelValues == null || !(toInt || toLong)) {
            return;
        }
        for (int i = 0; i < labelValues.length; i++) {
            if (labelValues[i] == null) {
                continue;
            }
            double parsed = Double.parseDouble(labelValues[i].toString());
            labelValues[i] = toInt ? (Object) Integer.valueOf((int) parsed) : (Object) Long.valueOf((long) parsed);
        }
    }

    /**
     * Evaluates the FM model on a sparse vector.
     *
     * @param sparseVector the feature vector
     * @param z            when true, the raw output is passed through the logistic sigmoid
     * @return the raw FM output, or its sigmoid when {@code z} is true
     */
    public double getY(SparseVector sparseVector, boolean z) {
        double y = ((Double) FmOptimizer.calcY(sparseVector, this.model.fmModel, this.dim).f0).doubleValue();
        return z ? sigmoid(y) : y;
    }

    /**
     * Logistic sigmoid {@code 1 / (1 + e^-z)}. Inputs below -37 return exactly 0.0 and inputs
     * above 34 return exactly 1.0. (The previous name, {@code logit}, was misleading: the logit
     * is the inverse of this function.)
     */
    private static double sigmoid(double z) {
        if (z < -37.0) {
            return 0.0;
        }
        if (z > 34.0) {
            return 1.0;
        }
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Returns the loaded model (null before {@link #loadModel(List)} is called). */
    public FmModelData getModel() {
        return this.model;
    }

    @Override
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictResultDetail(slicedSelectedSample).f0;
    }

    /**
     * Scores one row.
     *
     * <p>For regression the result is the raw FM output. The detail is {@code {"label":<value>}},
     * with the value formatted by {@code %f} under {@link Locale#ROOT} so it is always valid JSON.
     *
     * <p>For binary classification the sigmoid probability {@code p} selects
     * {@code labelValues[0]} when {@code p > 0.5}. Otherwise, including when {@code p} is NaN,
     * it selects {@code labelValues[1]}. The detail is a JSON map from each label to its
     * probability as a string.
     *
     * @throws AkUnsupportedOperationException for any other task type
     */
    @Override
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample) {
        // The vector column may hold a sparse or a dense vector, so use the general Vector type.
        Vector features;
        if (this.vectorColIndex != -1) {
            features = FeatureLabelUtil.getVectorFeature(
                slicedSelectedSample.get(this.vectorColIndex), false, this.model.vectorSize);
        } else {
            DenseVector buffer = this.threadLocalVec.get();
            slicedSelectedSample.fillDenseVector(buffer, false, this.featureIdx);
            features = buffer;
        }
        double y = ((Double) FmOptimizer.calcY(features, this.model.fmModel, this.dim).f0).doubleValue();

        if (this.model.task == BaseFmTrainBatchOp.Task.REGRESSION) {
            String detail = String.format(Locale.ROOT, "{\"%s\":%f}", "label", y);
            return Tuple2.of(Double.valueOf(y), detail);
        }
        if (this.model.task != BaseFmTrainBatchOp.Task.BINARY_CLASSIFICATION) {
            throw new AkUnsupportedOperationException("task not support yet");
        }

        double probability = sigmoid(y);
        Object prediction = probability <= 0.5 ? this.model.labelValues[1] : this.model.labelValues[0];
        // Initial capacity 0 is kept from the original so the JSON key order is unchanged.
        Map<String, String> detail = new HashMap<>(0);
        detail.put(this.model.labelValues[1].toString(), Double.toString(1.0 - probability));
        detail.put(this.model.labelValues[0].toString(), Double.toString(probability));
        return Tuple2.of(prediction, JsonConverter.toJson(detail));
    }
}
