package com.alibaba.alink.operator.common.nlp.bert;

import com.alibaba.alink.common.linalg.tensor.IntTensor;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.nlp.bert.tokenizer.EncodingKeys;
import com.alibaba.alink.operator.common.nlp.bert.tokenizer.Kwargs;
import com.alibaba.alink.operator.common.nlp.bert.tokenizer.PreTrainedTokenizer;
import com.alibaba.alink.operator.common.nlp.bert.tokenizer.SingleEncoding;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.params.tensorflow.bert.HasTextCol;
import com.alibaba.alink.params.tensorflow.bert.HasTextPairCol;
import java.util.Arrays;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/alibaba/alink/operator/common/nlp/bert/PreTrainedTokenizerMapper.class */
/**
 * Base mapper that feeds a text column, and optionally a paired text column,
 * through a {@link PreTrainedTokenizer} and emits one {@link IntTensor} output
 * column per {@link EncodingKeys} entry supplied by the subclass.
 *
 * <p>This class never assigns {@link #tokenizer}; a subclass has to set it before
 * {@link #map} runs, otherwise {@code map} fails with a NullPointerException.
 */
public abstract class PreTrainedTokenizerMapper extends Mapper {
    /** Prefix put in front of every encoding label to build the output column name. */
    private static final String SAFE_PREFIX = "alink_tokenizer_";

    /** Value of {@code HasTextCol.TEXT_COL} taken from the params. */
    final String textCol;
    /** Value of {@code HasTextPairCol.TEXT_PAIR_COL}, or {@code null} when that param is not set. */
    final String textPairCol;
    /** Encoding keys written to the result, in output-slot order. */
    final EncodingKeys[] outputKeys;

    /** Tokenizer used by {@link #map}; not assigned anywhere in this class. */
    protected PreTrainedTokenizer tokenizer;
    /** Options handed to {@code encodePlus}; starts out as {@code Kwargs.empty()}. */
    protected Kwargs encodeConfig = Kwargs.empty();

    /**
     * Reads the text column settings from {@code params} and caches the output keys.
     *
     * <p>Note that {@link #calcOutputKeys(Params)} is invoked from here, i.e. before any
     * subclass constructor body has executed.
     *
     * @param tableSchema schema of the input table, passed on to the super constructor
     * @param params      mapper parameters
     */
    public PreTrainedTokenizerMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.textCol = (String) params.get(HasTextCol.TEXT_COL);
        if (params.contains(HasTextPairCol.TEXT_PAIR_COL)) {
            this.textPairCol = (String) params.get(HasTextPairCol.TEXT_PAIR_COL);
        } else {
            this.textPairCol = null;
        }
        this.outputKeys = calcOutputKeys(params);
    }

    /**
     * Encodes one sample and stores each requested encoding as an {@link IntTensor}.
     *
     * <p>Slot 0 of the selected sample is used as the text; slot 1 is read as the paired
     * text only when a text-pair column was configured. Result slot {@code i} receives the
     * encoding for {@code outputKeys[i]}.
     *
     * @param slicedSelectedSample selected input values of the current row
     * @param slicedResult         receiver of the output tensors
     * @throws Exception declared by the overridden method
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        String text = (String) slicedSelectedSample.get(0);
        String textPair = null;
        if (this.textPairCol != null) {
            textPair = (String) slicedSelectedSample.get(1);
        }
        SingleEncoding encoding = this.tokenizer.encodePlus(text, textPair, this.encodeConfig);

        for (int slot = 0; slot < this.outputKeys.length; slot++) {
            EncodingKeys key = this.outputKeys[slot];
            slicedResult.set(slot, new IntTensor(encoding.get(key)));
        }
    }

    /**
     * Determines the input columns read by this mapper.
     *
     * @param params mapper parameters
     * @return the text column alone, or the text column followed by the text-pair column
     *         when {@code HasTextPairCol.TEXT_PAIR_COL} is present in {@code params}
     */
    protected String[] calcSelectedCols(Params params) {
        if (!params.contains(HasTextPairCol.TEXT_PAIR_COL)) {
            return new String[]{(String) params.get(HasTextCol.TEXT_COL)};
        }
        String first = (String) params.get(HasTextCol.TEXT_COL);
        String second = (String) params.get(HasTextPairCol.TEXT_PAIR_COL);
        return new String[]{first, second};
    }

    /**
     * Builds an output column name by prepending {@code "alink_tokenizer_"}.
     *
     * @param str label to prefix
     * @return the prefixed label
     */
    public static String prependPrefix(String str) {
        return SAFE_PREFIX + str;
    }

    /**
     * Supplies the encoding keys this mapper outputs, in output-column order.
     *
     * <p>Called from the constructor and again from {@link #prepareIoSchema}.
     *
     * @param params mapper parameters
     * @return the encoding keys to emit
     */
    protected abstract EncodingKeys[] calcOutputKeys(Params params);

    /**
     * Describes the mapper's inputs and outputs.
     *
     * @param tableSchema schema of the input table (not consulted here)
     * @param params      mapper parameters
     * @return a tuple of selected input columns, prefixed output column names,
     *         output types (all {@code AlinkTypes.INT_TENSOR}) and the reserved
     *         columns read from {@code HasReservedColsDefaultAsNull.RESERVED_COLS}
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, Params params) {
        String[] selectedCols = calcSelectedCols(params);

        // One output column per encoding key, named after the key's label.
        EncodingKeys[] keys = calcOutputKeys(params);
        String[] outputCols = new String[keys.length];
        for (int idx = 0; idx < keys.length; idx++) {
            outputCols[idx] = prependPrefix(keys[idx].label);
        }

        TypeInformation<?>[] outputTypes = new TypeInformation<?>[outputCols.length];
        Arrays.fill(outputTypes, AlinkTypes.INT_TENSOR);

        return Tuple4.of(selectedCols, outputCols, outputTypes, params.get(HasReservedColsDefaultAsNull.RESERVED_COLS));
    }
}
