package com.alibaba.alink.operator.common.regression;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.linear.AftRegObjFunc;
import com.alibaba.alink.operator.common.linear.FeatureLabelUtil;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.params.classification.LinearModelMapperParams;
import com.alibaba.alink.params.regression.AftRegPredictParams;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Model mapper for accelerated failure time (AFT) regression models.
 *
 * <p>For every input row it computes {@code lambda = exp(getDotProduct(features, coef))} as the
 * prediction. If {@code exp} overflows to {@code +Infinity}, {@code lambda} is clamped to
 * {@link Double#MAX_VALUE}.
 *
 * <p>The prediction detail is a serialized dense vector with one entry per configured quantile
 * probability {@code p}: {@code lambda * (-log(1 - p)) ^ s}. Here {@code s} is the last model
 * coefficient.
 */
public class AFTModelMapper extends RichModelMapper {
    private static final long serialVersionUID = 984867877738156476L;

    /**
     * Index of the vector feature column in the data schema. The value is -1 while features are
     * instead read from the individual columns listed in {@link #featureIdx}.
     */
    private int vectorColIndex;

    /** Probabilities at which quantiles are reported in the prediction detail. */
    private final double[] quantileProbabilities;

    private LinearModelData model;

    /** Data-schema indices of the model's named feature columns; only set when no vector column is used. */
    private int[] featureIdx;

    /** Per-thread reusable buffer for features gathered from individual columns. */
    private transient ThreadLocal<DenseVector> threadLocalVec;

    /**
     * Creates the mapper.
     *
     * @param modelSchema schema of the model table
     * @param dataSchema schema of the data table; the optional vector column is resolved against it
     * @param params prediction params. These supply the quantile probabilities and an optional
     *     vector column name.
     */
    public AFTModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.vectorColIndex = -1;
        this.quantileProbabilities = params.get(AftRegPredictParams.QUANTILE_PROBABILITIES);
        String vectorColName = params.get(LinearModelMapperParams.VECTOR_COL);
        if (vectorColName != null && !vectorColName.isEmpty()) {
            this.vectorColIndex = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), vectorColName);
        }
    }

    /**
     * Loads the linear model and, unless a vector column was set in the params, resolves where
     * features are read from.
     *
     * <p>If the model has no feature names, its own vector column is looked up in the data schema.
     * Otherwise each named feature column is looked up, and a per-thread buffer is prepared. The
     * buffer is sized to the feature count plus one extra slot when the model has an intercept.
     *
     * @param modelRows rows of the serialized model
     */
    @Override
    public void loadModel(List<Row> modelRows) {
        this.model = new LinearModelDataConverter(
            LinearModelDataConverter.extractLabelType(super.getModelSchema())).load(modelRows);
        if (this.vectorColIndex != -1) {
            // Vector column was given explicitly in the prediction params.
            return;
        }
        String[] dataFieldNames = getDataSchema().getFieldNames();
        if (this.model.featureNames == null) {
            this.vectorColIndex = TableUtil.findColIndexWithAssert(dataFieldNames, this.model.vectorColName);
            return;
        }
        int numFeatures = this.model.featureNames.length;
        int bufferSize = numFeatures + (this.model.hasInterceptItem ? 1 : 0);
        this.featureIdx = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            this.featureIdx[i] = TableUtil.findColIndexWithAssert(dataFieldNames, this.model.featureNames[i]);
        }
        this.threadLocalVec = ThreadLocal.withInitial(() -> new DenseVector(bufferSize));
    }

    /**
     * Returns {@code lambda} for the row, boxed as a {@link Double}.
     *
     * <p>Unlike {@link #predictResultDetail}, this does not compute or serialize the quantiles.
     */
    @Override
    protected Object predictResult(Mapper.SlicedSelectedSample selection) throws Exception {
        return Double.valueOf(computeLambda(extractFeatures(selection)));
    }

    /**
     * Returns {@code lambda} for the row together with the serialized quantile vector.
     *
     * @return {@code f0}: {@code lambda} as a {@link Double}; {@code f1}: the quantiles, one per
     *     configured probability, serialized as a {@link DenseVector}
     */
    @Override
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample selection) throws Exception {
        double lambda = computeLambda(extractFeatures(selection));
        double[] coefs = this.model.coefVector.getData();
        // NOTE(review): the last coefficient is presumably the AFT scale term; confirm against AftRegObjFunc.
        double scale = coefs[coefs.length - 1];
        double[] quantiles = new double[this.quantileProbabilities.length];
        for (int i = 0; i < quantiles.length; i++) {
            quantiles[i] = lambda * Math.exp(Math.log(-Math.log(1.0d - this.quantileProbabilities[i])) * scale);
        }
        return Tuple2.of(Double.valueOf(lambda), VectorUtil.serialize(new DenseVector(quantiles)));
    }

    /**
     * Builds the feature vector for one row and applies the model's intercept flag.
     *
     * <p>When the vector column is used, the result may be any {@link Vector} implementation. When
     * individual feature columns are used, it is this thread's reused {@link DenseVector}.
     */
    private Vector extractFeatures(Mapper.SlicedSelectedSample selection) throws Exception {
        if (this.vectorColIndex != -1) {
            return FeatureLabelUtil.getVectorFeature(
                selection.get(this.vectorColIndex), this.model.hasInterceptItem, this.model.vectorSize);
        }
        DenseVector buffer = this.threadLocalVec.get();
        selection.fillDenseVector(buffer, this.model.hasInterceptItem, this.featureIdx);
        return buffer;
    }

    /**
     * Computes {@code exp(getDotProduct(features, coef))}, replacing {@code +Infinity} with
     * {@link Double#MAX_VALUE}.
     */
    private double computeLambda(Vector features) {
        double lambda = Math.exp(AftRegObjFunc.getDotProduct(features, this.model.coefVector));
        return lambda == Double.POSITIVE_INFINITY ? Double.MAX_VALUE : lambda;
    }
}
