package com.alibaba.alink.operator.common.nlp.bert;

import com.alibaba.alink.common.dl.BertResources;
import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory;
import com.alibaba.alink.common.io.plugin.ResourcePluginFactory;
import com.alibaba.alink.common.mapper.ComboMapper;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.nlp.bert.tokenizer.EncodingKeys;
import com.alibaba.alink.operator.common.tensorflow.TFSavedModelPredictRowMapper;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.params.tensorflow.bert.BertTextEmbeddingParams;
import com.alibaba.alink.params.tensorflow.bert.HasHiddenStatesCol;
import com.alibaba.alink.params.tensorflow.bert.HasLengthCol;
import com.alibaba.alink.params.tensorflow.bert.HasMaxSeqLength;
import com.alibaba.alink.params.tensorflow.bert.HasMaxSeqLengthDefaultAsNull;
import com.alibaba.alink.params.tensorflow.savedmodel.HasOutputBatchAxes;
import com.alibaba.alink.params.tensorflow.savedmodel.TFSavedModelPredictParams;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/nlp/bert/BertTextEmbeddingMapper.class */
/**
 * Combo mapper that produces a BERT text embedding by chaining three mappers:
 * a {@link BertTokenizerMapper}, a {@link TFSavedModelPredictRowMapper} that runs the BERT
 * SavedModel, and a {@link BertEmbeddingExtractorMapper} that consumes the hidden-states and
 * length columns emitted by the first two stages.
 */
public class BertTextEmbeddingMapper extends ComboMapper {
    private final TFPredictorClassLoaderFactory factory;
    private final ResourcePluginFactory resourceFactory;

    /** Signature names of the SavedModel inputs, in the order they are fed to the predictor. */
    private static final String[] MODEL_INPUTS = {
        EncodingKeys.INPUT_IDS_KEY.label,
        EncodingKeys.TOKEN_TYPE_IDS_KEY.label,
        EncodingKeys.ATTENTION_MASK_KEY.label
    };
    private static final String HIDDEN_STATES_COL = "hidden_states";
    /** Signature names of the SavedModel outputs. */
    private static final String[] MODEL_OUTPUTS = {HIDDEN_STATES_COL};

    /**
     * Creates a mapper that uses a freshly constructed {@link TFPredictorClassLoaderFactory}.
     *
     * @param tableSchema schema of the input data
     * @param params      mapper parameters
     */
    public BertTextEmbeddingMapper(TableSchema tableSchema, Params params) {
        this(tableSchema, params, new TFPredictorClassLoaderFactory());
    }

    /**
     * Creates a mapper with an explicit predictor class-loader factory.
     *
     * @param tableSchema                   schema of the input data
     * @param params                        mapper parameters
     * @param tFPredictorClassLoaderFactory factory handed to the SavedModel predictor stage
     */
    public BertTextEmbeddingMapper(TableSchema tableSchema, Params params, TFPredictorClassLoaderFactory tFPredictorClassLoaderFactory) {
        super(tableSchema, params);
        this.factory = tFPredictorClassLoaderFactory;
        this.resourceFactory = new ResourcePluginFactory();
    }

    /**
     * Builds the tokenizer, predictor and extractor stages. Each stage receives its own copy of
     * this mapper's params, and each later stage is constructed on the output schema of the
     * stage before it.
     *
     * @return the three mappers in execution order
     */
    @Override // com.alibaba.alink.common.mapper.ComboMapper
    public List<Mapper> getLoadedMapperList() {
        // Columns to carry through the pipeline: the explicit reserved columns when the param is
        // present, otherwise every column of the input schema.
        final String[] reservedCols;
        if (this.params.contains(BertTextEmbeddingParams.RESERVED_COLS)) {
            reservedCols = this.params.get(BertTextEmbeddingParams.RESERVED_COLS);
        } else {
            reservedCols = getDataSchema().getFieldNames();
        }

        // Intermediate column names shared between the stages.
        final String hiddenStatesCol = PreTrainedTokenizerMapper.prependPrefix(HIDDEN_STATES_COL);
        final String lengthCol = PreTrainedTokenizerMapper.prependPrefix(EncodingKeys.LENGTH_KEY.label);

        // Stage 1: tokenizer.
        // NOTE(review): m1495clone() looks like the decompiler's alias for Params.clone() — the
        // declaration is not visible here, so the name is kept as-is.
        Params tokenizerParams = this.params.m1495clone();
        tokenizerParams.set(HasReservedColsDefaultAsNull.RESERVED_COLS, reservedCols);
        tokenizerParams.set(HasMaxSeqLengthDefaultAsNull.MAX_SEQ_LENGTH, this.params.get(HasMaxSeqLength.MAX_SEQ_LENGTH));
        BertTokenizerMapper tokenizerMapper = new BertTokenizerMapper(getDataSchema(), tokenizerParams, this.resourceFactory);

        // Stage 2: SavedModel prediction.
        Params predictParams = this.params.m1495clone();
        if (!predictParams.contains(TFSavedModelPredictParams.MODEL_PATH)) {
            // No explicit model path: resolve one from the configured BERT model name.
            String modelName = predictParams.get(BertTextEmbeddingParams.BERT_MODEL_NAME);
            predictParams.set(TFSavedModelPredictParams.MODEL_PATH, BertResources.getBertSavedModel(this.resourceFactory, modelName));
        }
        String[] prefixedInputCols = Arrays.stream(MODEL_INPUTS)
            .map(PreTrainedTokenizerMapper::prependPrefix)
            .toArray(String[]::new);
        TableSchema predictOutputSchema = TableSchema.builder()
            .field(hiddenStatesCol, AlinkTypes.FLOAT_TENSOR)
            .build();
        predictParams.set(TFSavedModelPredictParams.SELECTED_COLS, prefixedInputCols);
        predictParams.set(TFSavedModelPredictParams.INPUT_SIGNATURE_DEFS, MODEL_INPUTS);
        predictParams.set(TFSavedModelPredictParams.OUTPUT_SCHEMA_STR, TableUtil.schema2SchemaStr(predictOutputSchema));
        predictParams.set(TFSavedModelPredictParams.OUTPUT_SIGNATURE_DEFS, MODEL_OUTPUTS);
        // The length column must survive prediction because the extractor stage reads it.
        predictParams.set(TFSavedModelPredictParams.RESERVED_COLS, ArrayUtils.add(reservedCols, lengthCol));
        // NOTE(review): presumably axis 1 is the batch axis of the hidden-states output — confirm.
        predictParams.set(HasOutputBatchAxes.OUTPUT_BATCH_AXES, new int[]{1});
        TFSavedModelPredictRowMapper predictMapper =
            new TFSavedModelPredictRowMapper(tokenizerMapper.getOutputSchema(), predictParams, this.factory);

        // Stage 3: embedding extraction from the hidden states and sequence length.
        Params extractorParams = this.params.m1495clone();
        extractorParams.set(HasHiddenStatesCol.HIDDEN_STATES_COL, hiddenStatesCol);
        extractorParams.set(HasLengthCol.LENGTH_COL, lengthCol);
        extractorParams.set(HasReservedColsDefaultAsNull.RESERVED_COLS, reservedCols);
        BertEmbeddingExtractorMapper extractorMapper =
            new BertEmbeddingExtractorMapper(predictMapper.getOutputSchema(), extractorParams);

        return Arrays.asList(tokenizerMapper, predictMapper, extractorMapper);
    }

    /**
     * Declares this mapper's I/O: one selected input column, one output column typed
     * {@code Types.STRING}, and the reserved columns taken directly from the params.
     *
     * @param tableSchema schema of the input data (not consulted here)
     * @param params      mapper parameters
     * @return tuple of (selected columns, output columns, output types, reserved columns)
     */
    @Override // com.alibaba.alink.common.mapper.Mapper
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, Params params) {
        String selectedCol = params.get(BertTextEmbeddingParams.SELECTED_COL);
        String outputCol = params.get(BertTextEmbeddingParams.OUTPUT_COL);
        String[] reservedCols = params.get(BertTextEmbeddingParams.RESERVED_COLS);
        return Tuple4.of(
            new String[]{selectedCol},
            new String[]{outputCol},
            new TypeInformation<?>[]{Types.STRING},
            reservedCols);
    }
}
