package com.alibaba.alink.operator.common.finance.stepwiseSelector;

import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper;
import com.alibaba.alink.operator.common.feature.SelectorModelData;
import com.alibaba.alink.operator.common.feature.SelectorModelDataConverter;
import com.alibaba.alink.params.finance.SelectorPredictParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorModelMapper.class */
/**
 * Prediction-side mapper for a stepwise/feature selector model. It emits, per input row, a single
 * vector-typed output column holding only the features retained by the model.
 *
 * <p>The model may come in one of two layouts:
 * <ul>
 *   <li>{@code smd.vectorColNames != null}: each named column is read, converted with
 *       {@link VectorUtil#getVector}, and the parts are combined by
 *       {@link VectorAssemblerMapper#assembler}.</li>
 *   <li>otherwise: a single vector column is read and sliced to {@code smd.selectedIndices}. The
 *       column is {@code smd.vectorColName}, unless the {@code SELECTED_COL} parameter is set, in
 *       which case that parameter wins.</li>
 * </ul>
 */
public class SelectorModelMapper extends ModelMapper {
    private static final long serialVersionUID = -4884089344356950010L;

    /** Model data deserialized in {@link #loadModel(List)}. */
    private SelectorModelData smd;
    /** Data-schema positions of {@code smd.vectorColNames}; only populated in the multi-column layout. */
    private int[] selectedIndices;
    /** Data-schema position of the single vector column; only populated in the single-column layout. */
    private int selectedIdx;

    /**
     * @param modelSchema schema of the model rows (forwarded to the superclass)
     * @param dataSchema  schema of the data rows to be transformed
     * @param params      mapper parameters (prediction column, reserved columns, optional selected column)
     */
    public SelectorModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
    }

    /**
     * Deserializes the selector model and resolves, against the data schema, where the input
     * feature column(s) live. Column lookups go through the {@code *WithAssert} helpers of
     * {@link TableUtil}.
     *
     * @param modelRows serialized model rows
     */
    @Override
    public void loadModel(List<Row> modelRows) {
        smd = new SelectorModelDataConverter().load(modelRows);
        String[] dataCols = getDataSchema().getFieldNames();

        if (smd.vectorColNames == null) {
            // Single vector column: an explicit SELECTED_COL parameter overrides the model's column name.
            String vectorCol = params.contains(SelectorPredictParams.SELECTED_COL)
                ? params.get(SelectorPredictParams.SELECTED_COL)
                : smd.vectorColName;
            selectedIdx = TableUtil.findColIndexWithAssert(dataCols, vectorCol);
        } else {
            // Multiple columns: SELECTED_COL is not consulted in this layout.
            selectedIndices = TableUtil.findColIndicesWithAssert(dataCols, smd.vectorColNames);
        }
    }

    /**
     * Declares the mapper I/O.
     * <ul>
     *   <li>Input: every field of {@code dataSchema}.</li>
     *   <li>Output: one column, named by {@code PREDICTION_COL}, of type {@code AlinkTypes.VECTOR}.</li>
     *   <li>Kept columns: taken from {@code RESERVED_COLS}.</li>
     * </ul>
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema modelSchema, TableSchema dataSchema, Params params) {
        String[] inputCols = dataSchema.getFieldNames();
        String[] outputCols = new String[] {params.get(SelectorPredictParams.PREDICTION_COL)};
        TypeInformation<?>[] outputTypes = new TypeInformation<?>[] {AlinkTypes.VECTOR};
        String[] reservedCols = params.get(SelectorPredictParams.RESERVED_COLS);
        return Tuple4.of(inputCols, outputCols, outputTypes, reservedCols);
    }

    /**
     * Builds the selected-feature vector for one row and writes it to output slot 0.
     *
     * @param selection the row's selected input values
     * @param result    receives the resulting vector at index 0
     */
    @Override
    public void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result) {
        Vector output;
        if (smd.vectorColNames == null) {
            Vector full = VectorUtil.getVector(selection.get(selectedIdx));
            output = full.slice(smd.selectedIndices);
        } else {
            int numCols = smd.vectorColNames.length;
            Object[] parts = new Object[numCols];
            for (int k = 0; k < numCols; k++) {
                parts[k] = VectorUtil.getVector(selection.get(selectedIndices[k]));
            }
            output = (Vector) VectorAssemblerMapper.assembler(parts);
        }
        result.set(0, output);
    }
}
