package com.alibaba.alink.common.dl;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.dl.BaseEasyTransferTrainBatchOp;
import com.alibaba.alink.common.dl.utils.PythonFileUtils;
import com.alibaba.alink.common.io.plugin.ResourcePluginFactory;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter;
import com.alibaba.alink.operator.common.nlp.bert.BertTokenizerMapper;
import com.alibaba.alink.operator.common.nlp.bert.tokenizer.EncodingKeys;
import com.alibaba.alink.operator.common.tensorflow.CommonUtils;
import com.alibaba.alink.params.dl.HasBatchSizeDefaultAs32;
import com.alibaba.alink.params.dl.HasCheckpointFilePathDefaultAsNull;
import com.alibaba.alink.params.dl.HasIntraOpParallelism;
import com.alibaba.alink.params.dl.HasLearningRateDefaultAs0001;
import com.alibaba.alink.params.dl.HasModelPath;
import com.alibaba.alink.params.dl.HasNumPssDefaultAsNull;
import com.alibaba.alink.params.dl.HasNumWorkersDefaultAsNull;
import com.alibaba.alink.params.dl.HasPythonEnv;
import com.alibaba.alink.params.dl.HasTaskType;
import com.alibaba.alink.params.dl.HasUserFiles;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.tensorflow.bert.HasBertModelName;
import com.alibaba.alink.params.tensorflow.bert.HasCustomConfigJson;
import com.alibaba.alink.params.tensorflow.bert.HasMaxSeqLength;
import com.alibaba.alink.params.tensorflow.bert.HasMaxSeqLengthDefaultAsNull;
import com.alibaba.alink.params.tensorflow.bert.HasNumEpochsDefaultAs001;
import com.alibaba.alink.params.tensorflow.bert.HasNumFineTunedLayersDefaultAs1;
import com.alibaba.alink.params.tensorflow.bert.HasTaskName;
import com.alibaba.alink.params.tensorflow.bert.HasTextCol;
import com.alibaba.alink.params.tensorflow.bert.HasTextPairCol;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.TransformerBase;
import com.alibaba.alink.pipeline.nlp.BertTokenizer;
import com.google.common.collect.ImmutableMap;
import com.google.gson.reflect.TypeToken;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base batch operator that fine-tunes a pretrained BERT checkpoint through EasyTransfer.
 *
 * <p>{@link #linkFrom} does the following:
 * <ul>
 *   <li>tokenizes the first input with a {@link BertTokenizer};</li>
 *   <li>trains with {@link EasyTransferConfigTrainBatchOp} on the tokenizer output columns plus the label column;</li>
 *   <li>packs the training output into a model table. The schema comes from
 *       {@link TFTableModelClassificationModelDataConverter}. The serialized tokenizer pipeline and, for
 *       classification tasks, the distinct label list are broadcast into that step.</li>
 * </ul>
 *
 * @param <T> concrete operator type, returned by fluent methods
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@Internal
public class BaseEasyTransferTrainBatchOp<T extends BaseEasyTransferTrainBatchOp<T>> extends BatchOperator<T> {
    private static final Logger LOG = LoggerFactory.getLogger(BaseEasyTransferTrainBatchOp.class);

    /** Tokenizer output keys fed to the model, in the order they appear in the schema strings. */
    private static final String[] MODEL_INPUTS = {
        EncodingKeys.INPUT_IDS_KEY.label,
        EncodingKeys.TOKEN_TYPE_IDS_KEY.label,
        EncodingKeys.ATTENTION_MASK_KEY.label
    };

    /** Column names of {@link #MODEL_INPUTS}, i.e. each key passed through {@code BertTokenizerMapper.prependPrefix}. */
    private static final String[] SAFE_MODEL_INPUTS = Arrays.stream(MODEL_INPUTS)
        .map(key -> BertTokenizerMapper.prependPrefix(key))
        .toArray(String[]::new);

    /**
     * Layer count used to turn "number of layers to fine-tune" into "number of layers to freeze".
     * NOTE(review): presumably the encoder depth of a BERT-base model — confirm for other model sizes.
     */
    private static final int ASSUMED_NUM_ENCODER_LAYERS = 12;

    private final ResourcePluginFactory factory;

    public BaseEasyTransferTrainBatchOp() {
        this(new Params());
    }

    public BaseEasyTransferTrainBatchOp(Params params) {
        super(params);
        this.factory = new ResourcePluginFactory();
    }

    /**
     * Builds the default {@code preprocess_config} section of the EasyTransfer configuration.
     *
     * @param params       operator parameters. Reads the text, text-pair, label and max-sequence-length
     *                     parameters and the task type.
     * @param rawTextInput {@code true} when the training input carries the raw text columns. They are then
     *                     declared as {@code str} and named as first/second sequence. {@code false} when it
     *                     carries the tokenizer output columns.
     * @return a new mutable map
     */
    public static Map<String, Object> getPreprocessConfig(Params params, boolean rawTextInput) {
        String textCol = (String) params.get(HasTextCol.TEXT_COL);
        String textPairCol = params.contains(HasTextPairCol.TEXT_PAIR_COL)
            ? (String) params.get(HasTextPairCol.TEXT_PAIR_COL)
            : null;
        String labelCol = (String) params.get(HasLabelCol.LABEL_COL);
        int maxSeqLength = (Integer) params.get(HasMaxSeqLength.MAX_SEQ_LENGTH);
        boolean isClassification = TaskType.CLASSIFICATION.equals(params.get(HasTaskType.TASK_TYPE));

        Map<String, Object> config = new HashMap<>();
        if (rawTextInput) {
            // In raw-text mode the label is always declared as float, whatever the task type.
            StringBuilder schema = new StringBuilder()
                .append(schemaEntry(labelCol, "float", 1))
                .append(',')
                .append(schemaEntry(textCol, "str", 1));
            if (textPairCol != null) {
                schema.append(',').append(schemaEntry(textPairCol, "str", 1));
            }
            config.put("input_schema", schema.toString());
            config.put("first_sequence", textCol);
            if (textPairCol != null) {
                config.put("second_sequence", textPairCol);
            }
        } else {
            config.put("input_schema", tokenizedTrainingSchema(maxSeqLength, labelCol, isClassification));
        }
        config.put("sequence_length", maxSeqLength);
        config.put("label_name", labelCol);
        if (isClassification) {
            // Only binary classification is configured: two labels, enumerated as 0.0 and 1.0.
            config.put("num_labels", 2);
            config.put("label_enumerate_values", "0.0,1.0");
        } else {
            config.put("num_labels", 1);
        }
        return config;
    }

    /**
     * Builds the default {@code model_config} section of the EasyTransfer configuration.
     *
     * <p>Side effect: when {@code MODEL_PATH} is absent or null in {@code params}, the pretrained checkpoint for
     * {@code BERT_MODEL_NAME} is resolved and stored into {@code params}.
     *
     * @param params                operator parameters (mutated as described above)
     * @param resourcePluginFactory factory used to resolve the pretrained checkpoint
     * @return a new mutable map
     */
    public static Map<String, Object> getModelConfig(Params params, ResourcePluginFactory resourcePluginFactory) {
        ensureModelPath(params, resourcePluginFactory);
        boolean isClassification = TaskType.CLASSIFICATION.equals(params.get(HasTaskType.TASK_TYPE));
        int numFineTunedLayers = (Integer) params.get(HasNumFineTunedLayersDefaultAs1.NUM_FINE_TUNED_LAYERS);

        Map<String, Object> config = new HashMap<>();
        config.put("num_labels", isClassification ? 2 : 1);
        config.put("dropout_rate", 0.3);
        // Freeze every layer that is not fine-tuned; never a negative count.
        config.put("num_freezed_layers", Math.max(0, ASSUMED_NUM_ENCODER_LAYERS - numFineTunedLayers));
        config.put("keep_checkpoint_max", 1);
        return config;
    }

    /**
     * Builds the default {@code train_config} section: batch size, epochs and learning rate from {@code params},
     * plus a fixed checkpoint interval of 100 steps.
     *
     * @param params operator parameters
     * @return a new mutable map
     */
    public static Map<String, Object> getTrainConfig(Params params) {
        int batchSize = (Integer) params.get(HasBatchSizeDefaultAs32.BATCH_SIZE);
        double numEpochs = (Double) params.get(HasNumEpochsDefaultAs001.NUM_EPOCHS);
        double learningRate = (Double) params.get(HasLearningRateDefaultAs0001.LEARNING_RATE);

        Map<String, Object> config = new HashMap<>();
        config.put("train_batch_size", batchSize);
        config.put("save_steps", 100);
        config.put("num_epochs", numEpochs);
        config.put("optimizer_config", ImmutableMap.of("learning_rate", learningRate));
        return config;
    }

    /**
     * Builds the default {@code export_config} section: the input-tensor schema (tokenizer columns plus label)
     * and the serving receiver schema (tokenizer columns only).
     *
     * @param params operator parameters
     * @return a new mutable map
     */
    public static Map<String, Object> getExportConfig(Params params) {
        String labelCol = (String) params.get(HasLabelCol.LABEL_COL);
        int maxSeqLength = (Integer) params.get(HasMaxSeqLength.MAX_SEQ_LENGTH);
        boolean isClassification = TaskType.CLASSIFICATION.equals(params.get(HasTaskType.TASK_TYPE));

        Map<String, Object> config = new HashMap<>();
        config.put("input_tensors_schema", tokenizedTrainingSchema(maxSeqLength, labelCol, isClassification));
        config.put("receiver_tensors_schema", tokenizedInputsSchema(maxSeqLength));
        return config;
    }

    /**
     * Recursively merges {@code map2} over {@code map} into a new map. Neither input is modified.
     *
     * <p>For every key, a non-null value from {@code map2} wins. The exception is when both values are maps:
     * then they are merged recursively. A null value in {@code map2} keeps the value from {@code map}. A null
     * {@code map} is treated as empty. A null {@code map2} yields a shallow copy of {@code map}.
     *
     * @param map  default values
     * @param map2 overriding values; may be null
     * @return a new mutable map with the merged entries
     */
    @SuppressWarnings("unchecked")
    public static Map<String, Object> mergeMap(Map<String, Object> map, Map<String, Object> map2) {
        Map<String, Object> merged = map == null ? new HashMap<>() : new HashMap<>(map);
        if (map2 == null) {
            return merged;
        }
        for (Map.Entry<String, Object> entry : map2.entrySet()) {
            Object base = merged.get(entry.getKey());
            Object override = entry.getValue();
            Object value;
            if (override == null) {
                value = base;
            } else if (base instanceof Map && override instanceof Map) {
                // Nested sections come from JSON objects or from our own String-keyed defaults.
                value = mergeMap((Map<String, Object>) base, (Map<String, Object>) override);
            } else {
                value = override;
            }
            merged.put(entry.getKey(), value);
        }
        return merged;
    }

    /**
     * Builds the full EasyTransfer configuration: the four default sections, each merged with the
     * same-named section of the optional {@code CUSTOM_CONFIG_JSON} parameter via {@link #mergeMap}.
     * Other top-level sections in the custom JSON are ignored.
     *
     * <p>Side effect: see {@link #getModelConfig} ({@code MODEL_PATH} may be written into {@code params}).
     *
     * @param params                operator parameters
     * @param rawTextInput          forwarded to {@link #getPreprocessConfig}
     * @param resourcePluginFactory forwarded to {@link #getModelConfig}
     * @return section name to section map
     */
    public static Map<String, Map<String, Object>> getConfig(Params params, boolean rawTextInput,
                                                             ResourcePluginFactory resourcePluginFactory) {
        Map<String, Map<String, Object>> custom = parseCustomConfig(params);
        Map<String, Map<String, Object>> config = new HashMap<>();
        config.put("preprocess_config",
            mergeMap(getPreprocessConfig(params, rawTextInput), custom.get("preprocess_config")));
        config.put("model_config",
            mergeMap(getModelConfig(params, resourcePluginFactory), custom.get("model_config")));
        config.put("train_config", mergeMap(getTrainConfig(params), custom.get("train_config")));
        config.put("export_config", mergeMap(getExportConfig(params), custom.get("export_config")));
        return config;
    }

    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public T linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> data = inputs[0];
        Params params = getParams();
        TaskType taskType = (TaskType) params.get(HasTaskType.TASK_TYPE);
        boolean isClassification = TaskType.CLASSIFICATION.equals(taskType);
        String labelCol = (String) params.get(HasLabelCol.LABEL_COL);
        TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(data.getSchema(), labelCol);

        // Classification: collect the distinct labels and replace each label by its integer index.
        DataSet sortedLabels = null;
        if (isClassification) {
            sortedLabels = data.select(labelCol).distinct().getDataSet()
                .reduceGroup(new CommonUtils.SortLabelsReduceGroupFunction());
            data = EasyTransferUtils.mapLabelToIntIndex(data, labelCol, sortedLabels);
        }

        // Tokenize the input. The tokenizer pipeline is also saved so it can be embedded in the model.
        BertTokenizer tokenizer = (BertTokenizer) new BertTokenizer(params.clone())
            .set(HasMaxSeqLengthDefaultAsNull.MAX_SEQ_LENGTH, (Integer) params.get(HasMaxSeqLength.MAX_SEQ_LENGTH));
        PipelineModel preprocessModel = new PipelineModel(new TransformerBase<?>[] {tokenizer});
        BatchOperator<?> tokenized = preprocessModel.transform(data);
        BatchOperator<?> savedPreprocessModel = preprocessModel.save();
        String preprocessSchemaStr = TableUtil.schema2SchemaStr(savedPreprocessModel.getSchema());

        // Resolve the pretrained checkpoint once. Storing it in params lets getModelConfig reuse it.
        String bertModelCkpt = ensureModelPath(params, factory);

        Map<String, Object> userParams = new HashMap<>();
        String checkpointFilePath = (String) params.get(HasCheckpointFilePathDefaultAsNull.CHECKPOINT_FILE_PATH);
        if (!StringUtils.isNullOrWhitespaceOnly(checkpointFilePath)) {
            userParams.put("model_dir", checkpointFilePath);
        }
        ExternalFilesConfig userFiles = params.contains(HasUserFiles.USER_FILES)
            ? ExternalFilesConfig.fromJson((String) params.get(HasUserFiles.USER_FILES))
            : new ExternalFilesConfig();
        if (PythonFileUtils.isLocalFile(bertModelCkpt)) {
            // NOTE(review): assumes isLocalFile implies a "file://" prefix, which is stripped here.
            userParams.put("pretrained_ckpt_path", bertModelCkpt.substring("file://".length()));
        } else {
            // Remote checkpoint: ship it with the user files and refer to it by its compressed file name.
            userFiles.addFilePaths(bertModelCkpt);
            userParams.put("pretrained_ckpt_path", PythonFileUtils.getCompressedFileName(bertModelCkpt));
        }

        String configJson = JsonConverter.toJson(getConfig(params, false, factory));
        LOG.info("EasyTransfer config: {}", configJson);
        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.println("EasyTransfer config: " + configJson);
        }
        userParams.put("app_name", ((BertTaskName) params.get(HasTaskName.TASK_NAME)).name());

        EasyTransferConfigTrainBatchOp trainOp = (EasyTransferConfigTrainBatchOp) new EasyTransferConfigTrainBatchOp()
            .setSelectedCols(ArrayUtils.add(SAFE_MODEL_INPUTS, labelCol))
            .setConfigJson(configJson)
            .setUserFiles(userFiles)
            .setUserParams(JsonConverter.toJson(userParams))
            .setNumWorkers((Integer) params.get(HasNumWorkersDefaultAsNull.NUM_WORKERS))
            .setNumPSs((Integer) params.get(HasNumPssDefaultAsNull.NUM_PSS))
            .setPythonEnv((String) params.get(HasPythonEnv.PYTHON_ENV))
            .setIntraOpParallelism((Integer) params.get(HasIntraOpParallelism.INTRA_OP_PARALLELISM))
            .setMLEnvironmentId(getMLEnvironmentId());

        // Train on the tokenized data; any extra inputs are passed through unchanged.
        BatchOperator<?>[] trainInputs = inputs.clone();
        trainInputs[0] = tokenized;

        // Route all training output to one partition so a single task assembles the model rows.
        MapPartitionOperator modelRows = trainOp.linkFrom(trainInputs).getDataSet()
            .partitionCustom(new SinglePartitioner(), 0)
            .mapPartition(new CommonUtils.ConstructModelMapPartitionFunction(
                params, SAFE_MODEL_INPUTS, EasyTransferUtils.getTfOutputSignatureDef(taskType),
                EasyTransferUtils.TF_OUTPUT_SIGNATURE_TYPE, preprocessSchemaStr))
            .withBroadcastSet(savedPreprocessModel.getDataSet(), CommonUtils.PREPROCESS_PIPELINE_MODEL_BC_NAME);
        DataSet<Row> modelData = (DataSet<Row>) (isClassification
            ? modelRows.withBroadcastSet(sortedLabels, CommonUtils.SORTED_LABELS_BC_NAME)
            : modelRows);

        TableSourceBatchOp modelOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(
            getMLEnvironmentId(), modelData,
            new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()));
        modelOp.setMLEnvironmentId(getMLEnvironmentId());
        setOutputTable(modelOp.getOutputTable());
        return (T) this;
    }

    /**
     * Returns {@code MODEL_PATH} from {@code params}. When it is absent or null, first resolves the pretrained
     * checkpoint for {@code BERT_MODEL_NAME} and stores it into {@code params}.
     */
    private static String ensureModelPath(Params params, ResourcePluginFactory factory) {
        if (!params.contains(HasModelPath.MODEL_PATH) || null == params.get(HasModelPath.MODEL_PATH)) {
            String modelName = (String) params.get(HasBertModelName.BERT_MODEL_NAME);
            params.set(HasModelPath.MODEL_PATH, BertResources.getBertModelCkpt(factory, modelName));
        }
        return (String) params.get(HasModelPath.MODEL_PATH);
    }

    /** Parses {@code CUSTOM_CONFIG_JSON}; absent, null or empty JSON yields an empty map. */
    private static Map<String, Map<String, Object>> parseCustomConfig(Params params) {
        if (!params.contains(HasCustomConfigJson.CUSTOM_CONFIG_JSON)) {
            return new HashMap<>();
        }
        Map<String, Map<String, Object>> parsed = JsonConverter.fromJson(
            (String) params.get(HasCustomConfigJson.CUSTOM_CONFIG_JSON),
            new TypeToken<Map<String, Map<String, Object>>>() {}.getType());
        return parsed == null ? new HashMap<>() : parsed;
    }

    /** EasyTransfer tensor type of the label column: {@code int} for classification, else {@code float}. */
    private static String labelTensorType(boolean isClassification) {
        return isClassification ? "int" : "float";
    }

    /**
     * Formats one {@code name:type:length} schema entry. Uses plain concatenation so the number is always
     * written with ASCII digits, independent of the default locale.
     */
    private static String schemaEntry(String name, String type, int length) {
        return name + ":" + type + ":" + length;
    }

    /** Schema of the tokenizer output columns, each declared as {@code int} with length {@code maxSeqLength}. */
    private static String tokenizedInputsSchema(int maxSeqLength) {
        String[] entries = new String[SAFE_MODEL_INPUTS.length];
        for (int i = 0; i < entries.length; i++) {
            entries[i] = schemaEntry(SAFE_MODEL_INPUTS[i], "int", maxSeqLength);
        }
        return String.join(",", entries);
    }

    /** Schema of the tokenizer output columns followed by the label column. */
    private static String tokenizedTrainingSchema(int maxSeqLength, String labelCol, boolean isClassification) {
        return tokenizedInputsSchema(maxSeqLength) + "," + schemaEntry(labelCol, labelTensorType(isClassification), 1);
    }

    /**
     * Sends every record to partition 0. Keyed on field 0, which is typed here as {@code Long}.
     * A static nested class, so it does not capture the enclosing operator when serialized.
     */
    private static class SinglePartitioner implements Partitioner<Long> {
        private static final long serialVersionUID = 1L;

        @Override
        public int partition(Long key, int numPartitions) {
            return 0;
        }
    }
}
