package com.alibaba.alink.common.mapper;

import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.params.mapper.RichModelMapperParams;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/common/mapper/RichModelMapper.class */
/**
 * A {@link ModelMapper} that writes a prediction column and, optionally, a prediction-detail column.
 *
 * <p>If {@link RichModelMapperParams#PREDICTION_DETAIL_COL} is present in the params, each row
 * produces two output values: the prediction and a detail string, both taken from
 * {@link #predictResultDetail}. Otherwise each row produces a single prediction value, taken from
 * {@link #predictResult}.
 */
public abstract class RichModelMapper extends ModelMapper {
    private static final long serialVersionUID = -6722995426402759862L;

    /** True when the params request the prediction-detail output column. */
    private final boolean isPredDetail;

    /**
     * @param modelSchema schema passed to {@link #initPredResultColType} to decide the prediction
     *     column type.
     * @param dataSchema  schema whose field names are all handed to the superclass as the first
     *     element of the IO-schema tuple (presumably the selected input columns).
     * @param params      mapper parameters (prediction column, optional detail column, reserved columns).
     */
    public RichModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.isPredDetail = hasPredictionDetailCol(params);
    }

    /**
     * Determines the type of the prediction result column.
     *
     * <p>The default returns the type of field index 2 of {@code modelSchema}. It throws
     * {@link ArrayIndexOutOfBoundsException} if that schema has fewer than three fields.
     * Subclasses whose model layout differs should override this.
     *
     * @param modelSchema the first schema given to the constructor.
     * @return the type of the prediction output column.
     */
    protected TypeInformation<?> initPredResultColType(TableSchema modelSchema) {
        return modelSchema.getFieldTypes()[2];
    }

    /**
     * Builds the IO-schema tuple for the superclass.
     *
     * <p>The tuple contains: all field names of {@code dataSchema}; the output column names; the
     * output column types; and the reserved columns from the params. The outputs are the prediction
     * column, followed by the prediction-detail column ({@code STRING} type) when one is configured.
     *
     * @return a tuple of (data field names, output column names, output column types, reserved columns).
     */
    @Override
    protected final Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema modelSchema, TableSchema dataSchema, Params params) {
        String[] dataFieldNames = dataSchema.getFieldNames();
        TypeInformation<?> predResultColType = initPredResultColType(modelSchema);
        String predictionCol = (String) params.get(RichModelMapperParams.PREDICTION_COL);

        String[] outputCols;
        TypeInformation<?>[] outputTypes;
        // Read the flag from params instead of this.isPredDetail.
        // NOTE(review): this method is presumably called from the ModelMapper constructor, before
        // isPredDetail is assigned. Confirm against ModelMapper before switching to the field.
        if (hasPredictionDetailCol(params)) {
            String detailCol = (String) params.get(RichModelMapperParams.PREDICTION_DETAIL_COL);
            outputCols = new String[] {predictionCol, detailCol};
            outputTypes = new TypeInformation<?>[] {predResultColType, Types.STRING};
        } else {
            outputCols = new String[] {predictionCol};
            outputTypes = new TypeInformation<?>[] {predResultColType};
        }

        String[] reservedCols = (String[]) params.get(RichModelMapperParams.RESERVED_COLS);
        return Tuple4.of(dataFieldNames, outputCols, outputTypes, reservedCols);
    }

    /**
     * Writes the prediction into slot 0 of {@code result}. When the detail column is enabled, also
     * writes the detail string into slot 1.
     *
     * <p>NOTE(review): the decompiler reported that this method was {@code protected} in the
     * original source. It is kept {@code public} here so existing callers are not broken.
     */
    @Override
    public final void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result)
            throws Exception {
        if (isPredDetail) {
            Tuple2<Object, String> detail = predictResultDetail(selection);
            result.set(0, detail.f0);
            result.set(1, detail.f1);
        } else {
            result.set(0, predictResult(selection));
        }
    }

    /**
     * Predicts the result for one row when no detail column is requested.
     *
     * @param selection the selected input values of the row.
     * @return the prediction value, written to output slot 0.
     */
    protected abstract Object predictResult(Mapper.SlicedSelectedSample selection) throws Exception;

    /**
     * Predicts the result and its detail string for one row when the detail column is requested.
     *
     * @param selection the selected input values of the row.
     * @return {@code f0} = prediction (output slot 0), {@code f1} = detail string (output slot 1).
     */
    protected abstract Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample selection)
            throws Exception;

    /**
     * Returns whether {@code params} requests the prediction-detail output column.
     *
     * <p>This is static so it is safe to call before instance fields are initialized.
     */
    private static boolean hasPredictionDetailCol(Params params) {
        return params.contains(RichModelMapperParams.PREDICTION_DETAIL_COL);
    }
}
