package com.alibaba.alink.operator.common.regression;

import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.regression.glm.FamilyLink;
import com.alibaba.alink.operator.common.regression.glm.GlmUtil;
import com.alibaba.alink.params.mapper.RichModelMapperParams;
import com.alibaba.alink.params.regression.GlmPredictParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.Types;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/regression/GlmModelMapper.class */
/**
 * Model mapper that scores input rows with a generalized linear model (GLM).
 *
 * <p>The model is read from model rows via {@link GlmModelDataConverter}. For each input row the
 * mapper emits:
 * <ul>
 *   <li>output column 0: the value returned by {@link GlmUtil#predict} for the row's features and
 *       offset;</li>
 *   <li>output column 1 (only when {@code GlmPredictParams.LINK_PRED_RESULT_COL} is a non-empty
 *       name): {@link GlmUtil#linearPredict} for the row's features plus the offset.</li>
 * </ul>
 * Both outputs are typed as {@code DOUBLE}.
 */
public class GlmModelMapper extends ModelMapper {
    private static final long serialVersionUID = 8193374524901551398L;

    // NOTE: field names (including the misspelled "hasLinkPredit") are intentionally left
    // unchanged. This class is Serializable with a fixed serialVersionUID, so renaming a field
    // would silently drop its value when deserializing previously serialized instances.

    /** Coefficients from the loaded model, passed to GlmUtil together with the features. */
    private double[] coefficients;
    /** Intercept from the loaded model. */
    private double intercept;
    /** Family/link pair built from the loaded model's family and link settings. */
    private FamilyLink familyLink;
    /** Index of the offset column in the input, or a negative value when no offset is applied. */
    private int offsetColIdx;
    /** Input column indices of the model's feature columns, in model feature order. */
    private int[] featureColIdxs;
    /** Whether the extra link-prediction output column was requested. */
    private boolean hasLinkPredit;

    /**
     * Creates the mapper.
     *
     * @param tableSchema  schema of the model rows
     * @param tableSchema2 schema of the input data rows
     * @param params       prediction parameters; a non-empty
     *                     {@code GlmPredictParams.LINK_PRED_RESULT_COL} enables the second output
     */
    public GlmModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.hasLinkPredit = isNonEmpty((String) params.get(GlmPredictParams.LINK_PRED_RESULT_COL));
    }

    /**
     * Loads coefficients, intercept, family/link and resolves the feature and offset columns
     * against the input data schema.
     *
     * @param list serialized model rows
     * @throws RuntimeException whatever {@code TableUtil.findColIndexWithAssert} throws when a
     *                          model feature column is missing from the input schema
     */
    @Override
    public void loadModel(List<Row> list) {
        GlmModelData modelData = new GlmModelDataConverter().load(list);
        this.coefficients = modelData.coefficients;
        this.intercept = modelData.intercept;

        String[] dataColNames = getDataSchema().getFieldNames();

        // A blank offset name means "no offset". NOTE(review): if the model names an offset
        // column that the input lacks, findColIndex presumably returns -1 and map() then treats
        // the offset as 0 silently, whereas feature columns are resolved with an assertion.
        // Confirm this leniency is intended.
        this.offsetColIdx = isNonEmpty(modelData.offsetColName)
            ? TableUtil.findColIndex(dataColNames, modelData.offsetColName)
            : -1;

        String[] featureColNames = modelData.featureColNames;
        this.featureColIdxs = new int[featureColNames.length];
        for (int i = 0; i < featureColNames.length; i++) {
            this.featureColIdxs[i] = TableUtil.findColIndexWithAssert(dataColNames, featureColNames[i]);
        }

        this.familyLink = new FamilyLink(
            modelData.familyName, modelData.variancePower, modelData.linkName, modelData.linkPower);
    }

    /**
     * Declares the mapper's I/O schema. All input columns are selected. The output is the
     * prediction column, plus the link-prediction column when that parameter is non-empty. Both
     * output columns are DOUBLE. Reserved columns come from {@code RESERVED_COLS}.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        String predictionCol = (String) params.get(GlmPredictParams.PREDICTION_COL);
        String linkPredCol = (String) params.get(GlmPredictParams.LINK_PRED_RESULT_COL);

        String[] outputCols;
        TypeInformation<?>[] outputTypes;
        if (isNonEmpty(linkPredCol)) {
            outputCols = new String[]{predictionCol, linkPredCol};
            outputTypes = new TypeInformation<?>[]{Types.DOUBLE(), Types.DOUBLE()};
        } else {
            outputCols = new String[]{predictionCol};
            outputTypes = new TypeInformation<?>[]{Types.DOUBLE()};
        }
        return Tuple4.of(getDataSchema().getFieldNames(), outputCols, outputTypes, params.get(RichModelMapperParams.RESERVED_COLS));
    }

    /**
     * Scores one row and writes the prediction (and optionally the linear predictor).
     *
     * <p>Declared public here. The decompiler noted that the superclass declares it protected.
     *
     * @throws NullPointerException if a feature or offset cell is null
     * @throws ClassCastException   if a feature or offset cell is not a {@link Number}
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        double offset = this.offsetColIdx >= 0 ? readDouble(slicedSelectedSample, this.offsetColIdx) : 0.0d;

        double[] features = new double[this.featureColIdxs.length];
        for (int i = 0; i < features.length; i++) {
            features[i] = readDouble(slicedSelectedSample, this.featureColIdxs[i]);
        }

        double predict = GlmUtil.predict(this.coefficients, this.intercept, features, offset, this.familyLink);
        if (this.hasLinkPredit) {
            double linearPredict = GlmUtil.linearPredict(this.coefficients, this.intercept, features) + offset;
            slicedResult.set(0, Double.valueOf(predict));
            slicedResult.set(1, Double.valueOf(linearPredict));
        } else {
            slicedResult.set(0, Double.valueOf(predict));
        }
    }

    /**
     * Reads cell {@code colIdx} of the sample as a double.
     *
     * @throws NullPointerException if the cell is null; the message names the column index
     * @throws ClassCastException   if the cell is not a {@link Number}
     */
    private static double readDouble(Mapper.SlicedSelectedSample sample, int colIdx) {
        Object value = sample.get(colIdx);
        if (value == null) {
            throw new NullPointerException(
                "GLM prediction got a null value in input column " + colIdx
                    + "; feature and offset values must be non-null numbers.");
        }
        return ((Number) value).doubleValue();
    }

    /** Returns true when {@code s} is neither null nor the empty string. */
    private static boolean isNonEmpty(String s) {
        return s != null && !s.isEmpty();
    }
}
