package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.exceptions.AkUnimplementedOperationException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.classification.LinearModelMapperParams;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/LinearModelMapper.class */
/**
 * Model mapper that scores rows with a trained linear model (LR, SVM, Perceptron, LinearReg or SVR).
 *
 * <p>Features are taken either from a single vector column (named by {@code VECTOR_COL} in the params,
 * otherwise the vector column recorded in the model) or from the individual feature columns recorded
 * in the model. In the latter case they are copied into a per-thread reusable {@link DenseVector}.
 */
public class LinearModelMapper extends RichModelMapper {
    private static final long serialVersionUID = -1820786486066749971L;
    /** Index of the vector feature column in the data schema, or -1 when features come from separate columns. */
    private int vectorColIndex;
    /** The linear model used for scoring; set by one of the {@code loadModel} overloads. */
    private LinearModelData model;
    /** Data-schema indices of the model's feature columns; only populated when no vector column is used. */
    private int[] featureIdx;
    /** Number of feature columns recorded in the model; only populated when no vector column is used. */
    private int featureN;
    /** Name of the vector feature column, from the params or, failing that, from the model. */
    private String vectorColName;
    /** Per-thread scratch vector that separate feature columns are copied into before scoring. */
    private transient ThreadLocal<DenseVector> threadLocalVec;

    /**
     * Creates the mapper and, if {@code VECTOR_COL} is set in {@code params}, resolves that column in the
     * data schema immediately (failing via {@code TableUtil.findColIndexWithAssert} when it is absent).
     *
     * @param modelSchema schema of the model rows
     * @param dataSchema  schema of the rows to be predicted
     * @param params      mapper parameters; may be {@code null}
     */
    public LinearModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.vectorColIndex = -1;
        if (null != params) {
            this.vectorColName = (String) params.get(LinearModelMapperParams.VECTOR_COL);
            if (null == this.vectorColName || this.vectorColName.length() == 0) {
                return;
            }
            this.vectorColIndex = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), this.vectorColName);
        }
    }

    /**
     * Deserializes the model from its row representation and resolves the feature columns.
     *
     * @param list model rows, decoded with the label type taken from the model schema
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        this.model = new LinearModelDataConverter(LinearModelDataConverter.extractLabelType(super.getModelSchema())).load(list);
        initFeatureAccess();
    }

    /**
     * Loads a (copied) in-memory model and resolves the feature columns exactly as
     * {@link #loadModel(List)} does.
     *
     * @param linearModelData the model to copy and use
     */
    public void loadModel(LinearModelData linearModelData) {
        this.model = new LinearModelData(linearModelData);
        initFeatureAccess();
    }

    /**
     * Resolves where the features live in the data schema and (re)creates the per-thread scratch vector.
     *
     * <p>Shared by both {@code loadModel} overloads so they cannot drift apart. The scratch vector is
     * sized from the model's feature columns when the model has them, otherwise from
     * {@code model.vectorSize}, plus one slot when the model has an intercept term.
     */
    private void initFeatureAccess() {
        // A vector column given explicitly in the params takes precedence over what the model recorded.
        if (this.vectorColIndex == -1) {
            String[] fieldNames = getDataSchema().getFieldNames();
            if (this.model.featureNames == null) {
                this.vectorColName = this.model.vectorColName;
                this.vectorColIndex = TableUtil.findColIndexWithAssert(fieldNames, this.vectorColName);
            } else {
                this.featureN = this.model.featureNames.length;
                this.featureIdx = new int[this.featureN];
                for (int i = 0; i < this.featureN; i++) {
                    this.featureIdx[i] = TableUtil.findColIndexWithAssert(fieldNames, this.model.featureNames[i]);
                }
            }
        }
        int numFeatures = this.model.featureNames == null ? this.model.vectorSize : this.model.featureNames.length;
        final int denseSize = numFeatures + (this.model.hasInterceptItem ? 1 : 0);
        this.threadLocalVec = ThreadLocal.withInitial(() -> new DenseVector(denseSize));
    }

    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictResultDetail(slicedSelectedSample).f0;
    }

    /**
     * Predicts one row.
     *
     * @return the prediction paired with a JSON map from each label (as a string) to its score for LR
     *     and SVM models; for every other model type the detail string is {@code null}
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        Vector features;
        if (this.vectorColIndex != -1) {
            features = FeatureLabelUtil.getVectorFeature(
                slicedSelectedSample.get(this.vectorColIndex), this.model.hasInterceptItem, this.model.vectorSize);
        } else {
            // Reuse this thread's scratch vector instead of allocating one per row.
            DenseVector buffer = this.threadLocalVec.get();
            slicedSelectedSample.fillDenseVector(buffer, this.model.hasInterceptItem, this.featureIdx);
            features = buffer;
        }

        if (this.model.linearModelType != LinearModelType.LR && this.model.linearModelType != LinearModelType.SVM) {
            return new Tuple2<>(predict(features), null);
        }

        Tuple2<Object, Double[]> labelAndScores = predictWithProb(features);
        HashMap<String, String> scoreByLabel = new HashMap<>();
        for (int i = 0; i < this.model.labelValues.length; i++) {
            scoreByLabel.put(this.model.labelValues[i].toString(), labelAndScores.f1[i].toString());
        }
        return new Tuple2<>(labelAndScores.f0, JsonConverter.toJson(scoreByLabel));
    }

    /**
     * Scores a feature vector with the model coefficients.
     *
     * @param vector feature vector (including the intercept slot when the model has one)
     * @return for LR/SVM/Perceptron, {@code labelValues[0]} when {@code w·x >= 0} and
     *     {@code labelValues[1]} otherwise; for LinearReg/SVR, the raw {@code w·x} as a {@link Double}
     * @throws AkUnimplementedOperationException for any other model type
     */
    public Object predict(Vector vector) throws Exception {
        double dot = FeatureLabelUtil.dot(vector, this.model.coefVector);
        switch (this.model.linearModelType) {
            case LR:
            case SVM:
            case Perceptron:
                return dot >= 0.0 ? this.model.labelValues[0] : this.model.labelValues[1];
            case LinearReg:
            case SVR:
                return Double.valueOf(dot);
            default:
                throw new AkUnimplementedOperationException("Linear model type is Not implemented yet!");
        }
    }

    /**
     * Scores a feature vector and also returns per-label scores (LR and SVM only).
     *
     * @param vector feature vector (including the intercept slot when the model has one)
     * @return the predicted label (same rule as {@link #predict}) and
     *     {@code {sigmoid(w·x), 1 - sigmoid(w·x)}}, the scores for {@code labelValues[0]} and
     *     {@code labelValues[1]}
     * @throws AkUnimplementedOperationException for model types other than LR and SVM
     */
    public Tuple2<Object, Double[]> predictWithProb(Vector vector) {
        double dot = FeatureLabelUtil.dot(vector, this.model.coefVector);
        switch (this.model.linearModelType) {
            case LR:
            case SVM:
                double positive = sigmoid(dot);
                Object label = dot >= 0.0 ? this.model.labelValues[0] : this.model.labelValues[1];
                return new Tuple2<>(label, new Double[] {positive, 1.0d - positive});
            default:
                throw new AkUnimplementedOperationException("Current linear algo not supports score or detail yet!");
        }
    }

    /** Logistic function: {@code 1 - 1/(1+e^d)}, algebraically equal to {@code 1/(1+e^-d)}. */
    private double sigmoid(double d) {
        return 1.0d - (1.0d / (1.0d + Math.exp(d)));
    }
}
