package com.alibaba.alink.operator.common.tree.predictors;

import com.alibaba.alink.common.exceptions.XGboostException;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.outlier.OutlierUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.XGBoostModelDataConverter;
import com.alibaba.alink.operator.common.tree.xgboost.Booster;
import com.alibaba.alink.operator.common.tree.xgboost.plugin.XGBoostClassLoaderFactory;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import com.alibaba.alink.params.xgboost.HasObjective;
import com.alibaba.alink.params.xgboost.XGBoostInputParams;
import com.alibaba.alink.params.xgboost.XGBoostLearningTaskParams;
import com.alibaba.alink.params.xgboost.XGBoostPredictParams;
import java.io.IOException;
import java.io.InputStream;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/predictors/XGBoostModelMapper.class */
/**
 * Model mapper that scores rows with an XGBoost {@link Booster} restored from Alink model rows.
 *
 * <p>The raw booster output is turned into a prediction (and a JSON detail string) according to
 * the objective stored in the model meta: binary logistic, binary hinge, multi-class softmax and
 * multi-class softprob produce a label plus a per-label score map; any other objective yields the
 * first raw output value as a {@code Double} with no detail.
 */
public class XGBoostModelMapper extends RichModelMapper {
    /**
     * Label value paired with every feature vector handed to the booster. NOTE(review): presumably
     * ignored by the booster at prediction time — confirm against {@code Booster#predict}.
     */
    private static final float[] INITIAL_VALUE_OF_LABEL = {0.0f};
    private final XGBoostClassLoaderFactory xgBoostClassLoaderFactory;
    private transient Booster booster;
    private transient int vectorColIndex;
    private transient Object[] labels;
    private transient int vectorSize;
    private transient HasObjective.Objective objective;
    /** Projects an input row to a single-field row holding the feature vector. */
    private transient Function<Row, Row> selectFieldsFunction;

    /**
     * Creates the mapper.
     *
     * @param tableSchema  schema of the model table
     * @param tableSchema2 schema of the data table
     * @param params       predict parameters; {@code PLUGIN_VERSION} selects the XGBoost plugin
     */
    public XGBoostModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.xgBoostClassLoaderFactory =
            new XGBoostClassLoaderFactory((String) params.get(XGBoostPredictParams.PLUGIN_VERSION));
    }

    /**
     * Restores the booster, labels, objective and feature-selection function from model rows.
     *
     * @param list serialized model rows
     * @throws IllegalStateException wrapping any {@link XGboostException} or {@link IOException}
     *                               raised while loading the booster
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        XGBoostModelDataConverter converter = new XGBoostModelDataConverter();
        converter.load(list);
        this.labels = converter.labels;
        try {
            // The booster bytes are stored as a sequence of Base64 chunks; stream them lazily.
            this.booster = XGBoostClassLoaderFactory
                .create(this.xgBoostClassLoaderFactory)
                .create()
                .loadModel(new Base64ChunkInputStream(converter.modelData.iterator()));
            this.vectorSize =
                ((Integer) converter.meta.get(XGBoostModelDataConverter.XGBOOST_VECTOR_SIZE)).intValue();
            this.objective = (HasObjective.Objective) converter.meta.get(XGBoostLearningTaskParams.OBJECTIVE);

            if (converter.meta.contains(XGBoostInputParams.VECTOR_COL)) {
                // Trained on a vector column: forward that column as-is.
                int vectorIndex = TableUtil.findColIndexWithAssertAndHint(
                    getDataSchema(), (String) converter.meta.get(HasVectorCol.VECTOR_COL));
                this.selectFieldsFunction = row -> Row.of(new Object[] {row.getField(vectorIndex)});
            } else {
                // Trained on scalar feature columns (default: all numeric columns): pack them into
                // a dense vector.
                int[] featureIndices = TableUtil.findColIndicesWithAssertAndHint(
                    getDataSchema(),
                    OutlierUtil.uniformFeatureColsDefaultAsAll(
                        (String[]) converter.meta.get(XGBoostInputParams.FEATURE_COLS),
                        TableUtil.getNumericCols(getDataSchema())));
                this.selectFieldsFunction = row -> Row.of(new Object[] {
                    OutlierUtil.rowToDenseVector(row, featureIndices, featureIndices.length)});
            }
            // The selected row always holds the feature vector in field 0.
            this.vectorColIndex = 0;
        } catch (XGboostException | IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /** Returns only the predicted value of {@link #predictResultDetail}. */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictResultDetail(slicedSelectedSample).f0;
    }

    /**
     * Scores one sample.
     *
     * @param slicedSelectedSample the selected input columns
     * @return (prediction, detail JSON); both {@code null} when the booster returns no output.
     *         For the classification objectives handled below the detail maps each label to a
     *         score; for other objectives the detail is {@code null}.
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample)
        throws Exception {
        Row row = new Row(slicedSelectedSample.length());
        slicedSelectedSample.fillRow(row);
        float[] predict = this.booster.predict(row, this.selectFieldsFunction, selected -> {
            Vector vector = VectorUtil.getVector(selected.getField(this.vectorColIndex));
            // Sparse vectors parsed without an explicit size get the training-time dimension.
            if ((vector instanceof SparseVector) && vector.size() < 0) {
                ((SparseVector) vector).setSize(this.vectorSize);
            }
            return Tuple2.of(vector, INITIAL_VALUE_OF_LABEL);
        });
        if (predict == null || predict.length <= 0) {
            return Tuple2.of(null, null);
        }
        switch (this.objective) {
            case BINARY_LOGISTIC: {
                // predict[0] is treated as the score of labels[1].
                Map<Object, Double> detail = new HashMap<>();
                detail.put(this.labels[0], 1.0d - predict[0]);
                detail.put(this.labels[1], (double) predict[0]);
                Object label = predict[0] > 0.5f ? this.labels[1] : this.labels[0];
                return Tuple2.of(label, JsonConverter.toJson(detail));
            }
            case BINARY_HINGE: {
                // A raw output of exactly 1 selects labels[1]. NOTE(review): the "other" score uses
                // Criteria.INVALID_GAIN — presumably 0.0 inlined by the decompiler; confirm.
                boolean positive = predict[0] == 1.0f;
                Map<Object, Double> detail = new HashMap<>();
                detail.put(this.labels[0], positive ? Criteria.INVALID_GAIN : 1.0d);
                detail.put(this.labels[1], positive ? 1.0d : Criteria.INVALID_GAIN);
                Object label = positive ? this.labels[1] : this.labels[0];
                return Tuple2.of(label, JsonConverter.toJson(detail));
            }
            case MULTI_SOFTMAX: {
                // predict[0] is the index of the predicted label.
                Map<Object, Double> detail = new HashMap<>();
                for (int i = 0; i < this.labels.length; i++) {
                    detail.put(this.labels[i], predict[0] == i ? 1.0d : Criteria.INVALID_GAIN);
                }
                return Tuple2.of(this.labels[(int) predict[0]], JsonConverter.toJson(detail));
            }
            case MULTI_SOFTPROB: {
                // predict[i] is the score of labels[i]; the prediction is the arg-max (first on ties).
                Map<Object, Double> detail = new HashMap<>();
                double maxScore = Double.NEGATIVE_INFINITY;
                int maxIndex = 0;
                for (int i = 0; i < this.labels.length; i++) {
                    detail.put(this.labels[i], (double) predict[i]);
                    if (predict[i] > maxScore) {
                        maxScore = predict[i];
                        maxIndex = i;
                    }
                }
                return Tuple2.of(this.labels[maxIndex], JsonConverter.toJson(detail));
            }
            default:
                return Tuple2.of(Double.valueOf(predict[0]), null);
        }
    }

    /**
     * {@link InputStream} over the concatenation of Base64-decoded string chunks, decoded one chunk
     * at a time. Chunks that decode to zero bytes are skipped rather than treated as end of stream.
     */
    private static final class Base64ChunkInputStream extends InputStream {
        private final Iterator<String> chunks;
        private final Base64.Decoder decoder = Base64.getDecoder();
        private byte[] buffer;
        private int cursor;

        Base64ChunkInputStream(Iterator<String> chunks) {
            this.chunks = chunks;
        }

        /** Ensures at least one unread byte is buffered; returns {@code false} at end of data. */
        private boolean fill() {
            while (buffer == null || cursor >= buffer.length) {
                if (!chunks.hasNext()) {
                    return false;
                }
                buffer = decoder.decode(chunks.next());
                cursor = 0;
            }
            return true;
        }

        @Override
        public int read() {
            if (!fill()) {
                return -1;
            }
            return buffer[cursor++] & 0xFF;
        }

        @Override
        public int read(byte[] b, int off, int len) {
            if (b == null) {
                throw new NullPointerException();
            }
            if (off < 0 || len < 0 || len > b.length - off) {
                throw new IndexOutOfBoundsException();
            }
            if (len == 0) {
                return 0;
            }
            if (!fill()) {
                return -1;
            }
            int n = Math.min(len, buffer.length - cursor);
            System.arraycopy(buffer, cursor, b, off, n);
            cursor += n;
            return n;
        }
    }
}
