package com.alibaba.alink.operator.common.finance;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.params.classification.LinearModelMapperParams;
import com.alibaba.alink.params.finance.ScorePredictParams;
import com.alibaba.alink.params.mapper.RichModelMapperParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/finance/ScorePredictMapper.class */
/**
 * Model mapper that applies a linear (scorecard-style) model to each input row.
 *
 * <p>Depending on the parameters it emits, in this order:
 * <ol>
 *   <li>a JSON prediction detail (label -> probability) when {@code PREDICTION_DETAIL_COL} is set,</li>
 *   <li>the total score (dot product of features and coefficients) when {@code PREDICTION_SCORE_COL} is set,</li>
 *   <li>one score column per feature when {@code CALCULATE_SCORE_PER_FEATURE} is true.</li>
 * </ol>
 *
 * <p>Features are read either from a single vector column (taken from {@code LinearModelMapperParams.VECTOR_COL},
 * or from the model's {@code vectorColName}) or assembled from the model's feature columns.
 */
public class ScorePredictMapper extends ModelMapper {
    private static final long serialVersionUID = -6096135125528711852L;

    /** Index of the vector column in the data schema, or -1 when features are assembled from columns. */
    private int vectorColIndex;
    private LinearModelData model;
    /** Data-schema indices of the model's feature columns; only set when the model carries feature names. */
    private int[] featureIdx;
    /** Number of model feature columns; only set when the model carries feature names. */
    private int featureN;
    // NOTE(review): these three flags are assigned in prepareIoSchema, which presumably runs from inside the
    // super constructor (ModelMapper) — confirm. Do not give them field initializers: an initializer would run
    // after super() returns and overwrite the values prepareIoSchema assigned.
    private boolean calculateScore;
    private boolean calculateScorePerFeature;
    private boolean calculateDetail;
    /** For each model feature, its position within the user's FEATURE_COLS; null when not applicable. */
    private int[] selectedIndices;

    /**
     * Creates the mapper.
     *
     * @param tableSchema  schema of the model rows
     * @param tableSchema2 schema of the data rows
     * @param params       mapper parameters; may be null
     */
    public ScorePredictMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.vectorColIndex = -1;
        if (params == null) {
            return;
        }
        String vectorColName = (String) params.get(LinearModelMapperParams.VECTOR_COL);
        if (vectorColName != null && !vectorColName.isEmpty()) {
            this.vectorColIndex = TableUtil.findColIndexWithAssertAndHint(tableSchema2.getFieldNames(), vectorColName);
        }
    }

    /**
     * Returns the linear score, i.e. the dot product of {@code vector} with the model coefficients.
     */
    protected Double predictScore(Vector vector) {
        return Double.valueOf(MatVecOp.dot(vector, this.model.coefVector));
    }

    /**
     * Returns the per-entry contributions {@code value * coefficient}.
     *
     * <p>For a {@link DenseVector} there is one entry per position. For a {@link SparseVector} there is one entry
     * per stored (index, value) pair, in storage order.
     */
    protected double[] predictBinsScore(Vector vector) {
        if (vector instanceof SparseVector) {
            SparseVector sparse = (SparseVector) vector;
            double[] values = sparse.getValues();
            int[] indices = sparse.getIndices();
            double[] contributions = new double[values.length];
            for (int k = 0; k < values.length; k++) {
                contributions[k] = values[k] * this.model.coefVector.get(indices[k]);
            }
            return contributions;
        }
        double[] data = ((DenseVector) vector).getData();
        double[] contributions = new double[data.length];
        for (int k = 0; k < data.length; k++) {
            contributions[k] = data[k] * this.model.coefVector.get(k);
        }
        return contributions;
    }

    /**
     * Returns {@code 1 / (1 + exp(-score))} for the vector's linear score.
     *
     * <p>Returns 0.0 when the exponential is NaN or infinite.
     */
    protected Object predictResult(Vector vector) throws Exception {
        double exp = Math.exp(-MatVecOp.dot(vector, this.model.coefVector));
        return Double.valueOf((Double.isNaN(exp) || Double.isInfinite(exp)) ? 0.0d : 1.0d / (1.0d + exp));
    }

    /**
     * Loads the linear model and resolves where the features come from.
     *
     * <p>If no vector column was configured, this uses the model's vector column when the model has no feature
     * names. Otherwise it uses the model's feature columns.
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        this.model = new LinearModelDataConverter(LinearModelDataConverter.extractLabelType(super.getModelSchema()))
            .load(list);
        if (this.vectorColIndex != -1) {
            // A vector column was configured explicitly in the constructor; nothing more to resolve.
            return;
        }
        String[] fieldNames = getDataSchema().getFieldNames();
        if (this.model.featureNames == null) {
            this.vectorColIndex = TableUtil.findColIndexWithAssertAndHint(fieldNames, this.model.vectorColName);
            return;
        }
        this.featureN = this.model.featureNames.length;
        this.featureIdx = new int[this.featureN];
        for (int k = 0; k < this.featureN; k++) {
            this.featureIdx[k] = TableUtil.findColIndexWithAssertAndHint(fieldNames, this.model.featureNames[k]);
        }
        // Position of each model feature within the user-supplied FEATURE_COLS, used to order per-feature outputs.
        this.selectedIndices = TableUtil.findColIndicesWithAssertAndHint(
            (String[]) this.params.get(HasFeatureCols.FEATURE_COLS), this.model.featureNames);
    }

    /**
     * Scores one row and writes the configured outputs (detail, total score, per-feature scores) in that order.
     */
    @Override // com.alibaba.alink.common.mapper.Mapper
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
        throws Exception {
        Vector features = getFeatureVector(slicedSelectedSample, this.model.hasInterceptItem, this.featureN,
            this.featureIdx, this.vectorColIndex, Integer.valueOf(this.model.vectorSize));
        double[] binScores = predictBinsScore(features);
        double total = 0.0d;
        for (double s : binScores) {
            total += s;
        }

        int col = 0;
        if (this.calculateDetail) {
            slicedResult.set(col++, predictResultDetail(total));
        }
        if (this.calculateScore) {
            slicedResult.set(col++, Double.valueOf(total));
        }
        if (this.calculateScorePerFeature) {
            // With an intercept, getFeatureVector prepends a constant 1.0, so entry 0 is the intercept's
            // contribution rather than a feature's. Without an intercept, entry 0 is already the first feature.
            // NOTE(review): for sparse input this assumes each feature contributes exactly one stored entry
            // (e.g. one-hot bins) — confirm.
            int firstFeature = this.model.hasInterceptItem ? 1 : 0;
            for (int j = firstFeature; j < binScores.length; j++) {
                int featureOrdinal = j - firstFeature;
                int target = this.selectedIndices != null
                    ? col + this.selectedIndices[featureOrdinal]
                    : col + featureOrdinal;
                slicedResult.set(target, Double.valueOf(binScores[j]));
            }
        }
    }

    /**
     * Reads the output configuration and builds the output schema.
     *
     * <p>Side effect: sets {@code calculateDetail}, {@code calculateScore} and {@code calculateScorePerFeature}.
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
        TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        String[] reservedCols = (String[]) params.get(RichModelMapperParams.RESERVED_COLS);
        this.calculateScore = params.contains(ScorePredictParams.PREDICTION_SCORE_COL);
        this.calculateDetail = params.contains(ScorePredictParams.PREDICTION_DETAIL_COL);
        this.calculateScorePerFeature =
            ((Boolean) params.get(ScorePredictParams.CALCULATE_SCORE_PER_FEATURE)).booleanValue();

        List<String> outputCols = new ArrayList<>();
        List<TypeInformation<?>> outputTypes = new ArrayList<>();
        // Column order here must match the write order in map().
        if (this.calculateDetail) {
            outputCols.add((String) params.get(ScorePredictParams.PREDICTION_DETAIL_COL));
            outputTypes.add(Types.STRING);
        }
        if (this.calculateScore) {
            outputCols.add((String) params.get(ScorePredictParams.PREDICTION_SCORE_COL));
            outputTypes.add(Types.DOUBLE);
        }
        if (this.calculateScorePerFeature) {
            String[] perFeatureCols = (String[]) params.get(ScorePredictParams.PREDICTION_SCORE_PER_FEATURE_COLS);
            for (String colName : perFeatureCols) {
                outputCols.add(colName);
                outputTypes.add(Types.DOUBLE);
            }
        }
        return Tuple4.of(getDataSchema().getFieldNames(), outputCols.toArray(new String[0]),
            outputTypes.toArray(new TypeInformation<?>[0]), reservedCols);
    }

    /**
     * Builds a JSON map from each label value to its probability.
     *
     * <p>{@code labelValues[0]} gets {@code sigmoid(score)} and {@code labelValues[1]} gets the complement.
     * NOTE(review): only two probabilities are computed, so more than two label values would overflow the array.
     */
    protected String predictResultDetail(double d) {
        Double[] probs = predictWithProb(d);
        // Capacity kept as in the original so the resulting JSON key order is unchanged.
        HashMap<String, String> labelToProb = new HashMap<>(1);
        int labelCount = this.model.labelValues.length;
        for (int k = 0; k < labelCount; k++) {
            labelToProb.put(this.model.labelValues[k].toString(), probs[k].toString());
        }
        return JsonConverter.toJson(labelToProb);
    }

    /** Returns {@code [sigmoid(d), 1 - sigmoid(d)]}. */
    private Double[] predictWithProb(double d) {
        double p = sigmoid(d);
        return new Double[] {Double.valueOf(p), Double.valueOf(1.0d - p)};
    }

    /** Logistic function, written as {@code 1 - 1 / (1 + e^d)}. */
    private double sigmoid(double d) {
        return 1.0d - (1.0d / (1.0d + Math.exp(d)));
    }

    /**
     * Extracts the feature vector for one row.
     *
     * @param sample        the selected input columns of the row
     * @param hasIntercept  whether to prepend a constant 1.0 term
     * @param featureCount  number of feature columns to assemble (used only when {@code vectorCol == -1})
     * @param featureCols   indices of the feature columns (used only when {@code vectorCol == -1})
     * @param vectorCol     index of the vector column, or -1 to assemble from feature columns
     * @param vectorSize    if non-null, the size set on a sparse input vector
     */
    private static Vector getFeatureVector(Mapper.SlicedSelectedSample sample, boolean hasIntercept,
                                           int featureCount, int[] featureCols, int vectorCol, Integer vectorSize) {
        if (vectorCol != -1) {
            Vector raw = VectorUtil.getVector(sample.get(vectorCol));
            if (raw instanceof SparseVector) {
                SparseVector sparse = (SparseVector) raw;
                if (vectorSize != null) {
                    sparse.setSize(vectorSize.intValue());
                }
                return hasIntercept ? sparse.prefix(1.0d) : sparse;
            }
            DenseVector dense = (DenseVector) raw;
            return hasIntercept ? dense.prefix(1.0d) : dense;
        }
        int offset = hasIntercept ? 1 : 0;
        Object[] parts = new Object[featureCount + offset];
        if (hasIntercept) {
            parts[0] = Double.valueOf(1.0d);
        }
        for (int k = 0; k < featureCount; k++) {
            parts[offset + k] = VectorUtil.getVector(sample.get(featureCols[k]));
        }
        return (Vector) VectorAssemblerMapper.assembler(parts);
    }
}
