package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.classification.MultilayerPerceptronPredictParams;
import com.alibaba.alink.params.classification.MultilayerPerceptronTrainParams;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/ann/MlpcModelMapper.class */
/**
 * Model mapper that scores rows with a trained multilayer perceptron classifier.
 *
 * <p>{@link #loadModel(List)} rebuilds the feed-forward network from the serialized model and resolves
 * where the features live in the data schema: either one vector column or a list of numeric columns.
 * Prediction runs the features through the network and returns the label whose output value is
 * largest. The detail string is a JSON map from every label to its output value.
 */
public class MlpcModelMapper extends RichModelMapper {
    private static final long serialVersionUID = 2691422221337359053L;

    /** True when features come from a single vector column rather than several numeric columns. */
    private boolean isVectorInput;
    /** Index of the vector column in the data schema; only set when {@link #isVectorInput} is true. */
    private int vectorColIdx;
    /** Indices of the numeric feature columns in the data schema; only set for non-vector input. */
    private int[] featureColIdx;
    /** Network rebuilt from the model weights by {@link #loadModel(List)}. */
    private transient TopologyModel topo;
    /** Class labels, positionally aligned with the network's output values. */
    private transient List<Object> labels;
    /** Numeric feature column names recorded in the model; only set for non-vector input. */
    String[] featureColNames;
    /** One reusable feature buffer per thread, sized to the network's first layer. */
    private transient ThreadLocal<DenseVector> threadLocalVec;

    public MlpcModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
    }

    /**
     * Extracts the feature values of one row.
     *
     * @param sample        the row, indexed by data-schema column position
     * @param isVectorInput whether features come from a single vector column
     * @param featureColIdx data-schema indices of the numeric feature columns (non-vector input)
     * @param vectorColIdx  data-schema index of the vector column (vector input)
     * @param buffer        reusable buffer sized to the network's input layer
     * @return the vector to feed to the network: normally {@code buffer}, or a private copy of the
     *     input when a dense input's length differs from the buffer's (so the shared buffer keeps its
     *     size and the caller's array is never aliased)
     */
    private static DenseVector getFeaturesVector(Mapper.SlicedSelectedSample sample, boolean isVectorInput,
                                                 int[] featureColIdx, int vectorColIdx, DenseVector buffer) {
        if (!isVectorInput) {
            // Read each configured feature column, not the first N columns of the row.
            for (int i = 0; i < featureColIdx.length; i++) {
                buffer.set(i, ((Number) sample.get(featureColIdx[i])).doubleValue());
            }
            return buffer;
        }

        Vector vector = VectorUtil.getVector(sample.get(vectorColIdx));
        double[] data = buffer.getData();
        if (vector == null) {
            // A missing vector is scored as all-zero features.
            Arrays.fill(data, 0.0);
            return buffer;
        }

        if (vector instanceof DenseVector) {
            double[] source = ((DenseVector) vector).getData();
            if (source.length == data.length) {
                // Copy rather than adopt the input's array: adopting it would let later rows
                // overwrite (zero-fill / scatter into) a previous row's vector.
                System.arraycopy(source, 0, data, 0, source.length);
                return buffer;
            }
            // Size mismatch: pass the values through unchanged (as before) but without resizing
            // the shared buffer or aliasing the caller's array.
            return new DenseVector(source.clone());
        }

        // Sparse input: densify into a zeroed buffer.
        Arrays.fill(data, 0.0);
        SparseVector sparse = (SparseVector) vector;
        int[] indices = sparse.getIndices();
        double[] values = sparse.getValues();
        for (int k = 0; k < indices.length; k++) {
            buffer.set(indices[k], values[k]);
        }
        return buffer;
    }

    @Override
    public void loadModel(List<Row> modelRows) {
        MlpcModelData modelData = new MlpcModelDataConverter().load(modelRows);
        modelData.labelType = super.getModelSchema().getFieldTypes()[2];
        this.labels = modelData.labels;

        int[] layerSizes = (int[]) modelData.meta.get(MultilayerPerceptronTrainParams.LAYERS);
        this.topo = FeedForwardTopology.multiLayerPerceptron(layerSizes, true).getModel(modelData.weights);
        this.isVectorInput = ((Boolean) modelData.meta.get(ModelParamName.IS_VECTOR_INPUT)).booleanValue();

        TableSchema dataSchema = getDataSchema();
        if (this.isVectorInput) {
            // A vector column set on the predict params takes precedence over the one stored in the model.
            String vectorColName = this.params.contains(MultilayerPerceptronPredictParams.VECTOR_COL)
                ? (String) this.params.get(MultilayerPerceptronPredictParams.VECTOR_COL)
                : (String) modelData.meta.get(MultilayerPerceptronPredictParams.VECTOR_COL);
            this.vectorColIdx = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), vectorColName);
        } else {
            this.featureColNames = (String[]) modelData.meta.get(MultilayerPerceptronTrainParams.FEATURE_COLS);
            this.featureColIdx = TableUtil.findColIndicesWithAssert(dataSchema.getFieldNames(), this.featureColNames);
        }

        // The first layer size is the dimension of the feature vector fed to the network.
        final int inputSize = layerSizes[0];
        this.threadLocalVec = ThreadLocal.withInitial(() -> new DenseVector(inputSize));
    }

    /** Extracts the row's features into this thread's buffer and runs them through the network. */
    private DenseVector predictScores(Mapper.SlicedSelectedSample sample) throws Exception {
        DenseVector features = getFeaturesVector(
            sample, this.isVectorInput, this.featureColIdx, this.vectorColIdx, this.threadLocalVec.get());
        return this.topo.predict(features);
    }

    /**
     * Returns the index of the first largest value, or 0 if no value compares greater than negative
     * infinity (e.g. all NaN). Starting from negative infinity instead of 0.0 guarantees a valid
     * index even when no output is strictly positive.
     */
    private static int argMax(DenseVector values) {
        int bestIndex = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.size(); i++) {
            double value = values.get(i);
            if (value > bestValue) {
                bestIndex = i;
                bestValue = value;
            }
        }
        return bestIndex;
    }

    /**
     * Predicts the label of one row together with a JSON map of every label's network output value.
     *
     * @return (label with the largest output, JSON object mapping each label to its output value)
     */
    @Override
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample sample) throws Exception {
        DenseVector scores = predictScores(sample);
        HashMap<Comparable<?>, Double> detail = new HashMap<>(scores.size());
        for (int i = 0; i < scores.size(); i++) {
            detail.put((Comparable<?>) this.labels.get(i), scores.get(i));
        }
        return Tuple2.of(this.labels.get(argMax(scores)), JsonConverter.toJson(detail));
    }

    /** Predicts only the label of one row; skips building and serializing the detail map. */
    @Override
    protected Object predictResult(Mapper.SlicedSelectedSample sample) throws Exception {
        return this.labels.get(argMax(predictScores(sample)));
    }
}
