package com.alibaba.alink.common.dl;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.dl.BaseKerasSequentialTrainBatchOp;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.tensorflow.TF2TableModelTrainBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelDataConverter;
import com.alibaba.alink.operator.common.tensorflow.CommonUtils;
import com.alibaba.alink.params.dl.HasPythonEnv;
import com.alibaba.alink.params.dl.HasTaskType;
import com.alibaba.alink.params.tensorflow.kerasequential.BaseKerasSequentialTrainParams;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.StringUtils;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@Internal
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "tensorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}),
    @ParamSelectColumnSpec(name = "labelCol")
})
/**
 * Base batch operator that trains a Keras Sequential model through the TF2 table-model trainer
 * and packages the result as an Alink model table.
 *
 * <p>The training itself is delegated to {@link TF2TableModelTrainBatchOp}, which runs the bundled
 * Python script {@code res:///tf_algos/train_keras_sequential.py} with user params built from this
 * operator's parameters. For classification tasks the label column is first mapped to indices
 * derived from the sorted set of distinct labels, and that sorted label list is broadcast into the
 * model-construction step.
 *
 * @param <T> concrete operator type (self type), returned from {@link #linkFrom}
 */
public class BaseKerasSequentialTrainBatchOp<T extends BaseKerasSequentialTrainBatchOp<T>> extends BatchOperator<T>
    implements BaseKerasSequentialTrainParams<T> {

    /** Output signature name requested from the exported model for classification tasks. */
    static final String TF_OUTPUT_SIGNATURE_DEF_CLASSIFICATION = "logits";
    /** Output signature name requested from the exported model for regression tasks. */
    static final String TF_OUTPUT_SIGNATURE_DEF_REGRESSION = "y";
    /** Alink type of the model output tensor. */
    static final TypeInformation<?> TF_OUTPUT_SIGNATURE_TYPE = AlinkTypes.FLOAT_TENSOR;
    /** Bundled Python entry script executed by the TF2 trainer. */
    private static final String MAIN_SCRIPT_FILE_NAME = "res:///tf_algos/train_keras_sequential.py";
    /** Files shipped to the Python workers; only the main script. */
    private static final String[] RES_PY_FILES = {MAIN_SCRIPT_FILE_NAME};

    public BaseKerasSequentialTrainBatchOp() {
        this(null);
    }

    public BaseKerasSequentialTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Returns the TF output signature name for the given task type.
     *
     * @param taskType the task type; anything other than {@code CLASSIFICATION} (including
     *                 {@code null}) yields the regression signature
     * @return {@code "logits"} for classification, {@code "y"} otherwise
     */
    public static String getTfOutputSignatureDef(TaskType taskType) {
        return TaskType.CLASSIFICATION.equals(taskType)
            ? TF_OUTPUT_SIGNATURE_DEF_CLASSIFICATION
            : TF_OUTPUT_SIGNATURE_DEF_REGRESSION;
    }

    /**
     * Builds the training and model-construction dataflow from the first input.
     *
     * @param inputs inputs; only {@code inputs[0]} (the training data) is used
     * @return this operator, with its output table set to the model table
     */
    @Override
    @SuppressWarnings("unchecked")
    public T linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = inputs[0];
        Params params = getParams();

        TaskType taskType = (TaskType) params.get(HasTaskType.TASK_TYPE);
        boolean isRegression = TaskType.REGRESSION.equals(taskType);
        String tensorCol = getTensorCol();
        String labelCol = getLabelCol();
        // Resolved from the ORIGINAL input schema, before any label-to-index mapping.
        TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);

        DataSet<List<Object>> sortedLabels = null;
        BatchOperator<?> numLabelsOp = null;
        if (!isRegression) {
            sortedLabels = collectSortedLabels(in, labelCol);
            // Presumably replaces each label by its position in sortedLabels — see CommonUtils.
            in = CommonUtils.mapLabelToIndex(in, labelCol, sortedLabels);
            numLabelsOp = countLabels(sortedLabels);
        }

        TF2TableModelTrainBatchOp trainOp = new TF2TableModelTrainBatchOp(params)
            .setSelectedCols(tensorCol, labelCol)
            .setUserFiles(RES_PY_FILES)
            .setMainScriptFile(MAIN_SCRIPT_FILE_NAME)
            .setUserParams(JsonConverter.toJson(buildUserParams(params, tensorCol, labelCol)))
            .setIntraOpParallelism(getIntraOpParallelism())
            .setNumPSs(getNumPSs())
            .setNumWorkers(getNumWorkers())
            .setPythonEnv((String) params.get(HasPythonEnv.PYTHON_ENV));

        // NOTE(review): when taskType is null, the data path above is treated as classification
        // (isRegression == false) while getTfOutputSignatureDef(null) picks the regression
        // signature "y". Confirm that TASK_TYPE always has a non-null value.
        FlatMapOperator<?, ?> modelOp = ((NumSeqSourceBatchOp) new NumSeqSourceBatchOp()
            .setFrom(0L)
            .setTo(0L)
            .setMLEnvironmentId(getMLEnvironmentId()))
            .getDataSet()
            .flatMap(new CommonUtils.ConstructModelFlatMapFunction(
                params, new String[] {tensorCol}, getTfOutputSignatureDef(taskType),
                TF_OUTPUT_SIGNATURE_TYPE, null, true));

        BatchOperator<?> trained = isRegression
            ? trainOp.linkFrom(in)
            : trainOp.linkFrom(in, numLabelsOp);
        modelOp.withBroadcastSet(trained.getDataSet(), CommonUtils.TF_MODEL_BC_NAME);

        setOutputTable(buildModelTable(modelOp, isRegression, sortedLabels, labelType).getOutputTable());
        return (T) this;
    }

    /** Collects the distinct label values of {@code labelCol} into one sorted list. */
    private static DataSet<List<Object>> collectSortedLabels(BatchOperator<?> in, String labelCol) {
        return in.select(labelCol)
            .getDataSet()
            .mapPartition(new DistinctLabelsMapPartition())
            .reduceGroup(new SortLabelsGroupReduce());
    }

    /** Wraps the number of distinct labels as a single-row, single-INT-column operator. */
    @SuppressWarnings("unchecked")
    private BatchOperator<?> countLabels(DataSet<List<Object>> sortedLabels) {
        // NOTE(review): reuses a DBSCAN constant as the column name; presumably it equals "count".
        // A local constant would be clearer.
        return new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                getMLEnvironmentId(),
                (DataSet<Row>) sortedLabels.map(new CommonUtils.CountLabelsMapFunction()),
                new String[] {DbscanConstant.COUNT},
                new TypeInformation<?>[] {Types.INT}))
            .setMLEnvironmentId(getMLEnvironmentId());
    }

    /** Builds the user-params map that is serialized to JSON and passed to the Python script. */
    private Map<String, String> buildUserParams(Params params, String tensorCol, String labelCol) {
        Map<String, Object> modelConfig = new HashMap<>();
        modelConfig.put("layers", getLayers());

        Map<String, String> userParams = new HashMap<>();
        Boolean removeCheckpoint = getRemoveCheckpointBeforeTraining();
        // An unset value is treated as true.
        if (removeCheckpoint == null || removeCheckpoint) {
            userParams.put(DLConstants.REMOVE_CHECKPOINT_BEFORE_TRAINING, "true");
        }
        userParams.put("tensor_cols", JsonConverter.toJson(new String[] {tensorCol}));
        userParams.put("label_col", labelCol);
        userParams.put("label_type", "float");
        userParams.put("batch_size", String.valueOf(getBatchSize()));
        userParams.put("num_epochs", String.valueOf(getNumEpochs()));
        userParams.put("model_config", JsonConverter.toJson(modelConfig));
        userParams.put("optimizer", getOptimizer());

        String checkpointFilePath = getCheckpointFilePath();
        if (!StringUtils.isNullOrWhitespaceOnly(checkpointFilePath)) {
            userParams.put("model_dir", checkpointFilePath);
        }
        if (MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().getParallelism() == 1) {
            userParams.put("ALINK:ONLY_ONE_WORKER", "true");
        }
        userParams.put("validation_split", String.valueOf(getValidationSplit()));
        userParams.put("save_best_only", String.valueOf(getSaveBestOnly()));
        userParams.put("best_exporter_metric", getBestMetric());
        userParams.put("save_checkpoints_epochs", String.valueOf(getSaveCheckpointsEpochs()));
        // Only forwarded when explicitly set.
        if (params.contains(BaseKerasSequentialTrainParams.SAVE_CHECKPOINTS_SECS)) {
            userParams.put("save_checkpoints_secs", String.valueOf(getSaveCheckpointsSecs()));
        }
        return userParams;
    }

    /**
     * Wraps the model-construction output as a table operator with the regression or
     * classification model schema. For classification, it also broadcasts the sorted labels.
     */
    @SuppressWarnings("unchecked")
    private BatchOperator<?> buildModelTable(FlatMapOperator<?, ?> modelOp, boolean isRegression,
                                             DataSet<List<Object>> sortedLabels,
                                             TypeInformation<?> labelType) {
        if (isRegression) {
            return new TableSourceBatchOp(
                DataSetConversionUtil.toTable(
                    getMLEnvironmentId(),
                    (DataSet<Row>) modelOp,
                    new TFTableModelRegressionModelDataConverter(labelType).getModelSchema()))
                .setMLEnvironmentId(getMLEnvironmentId());
        }
        modelOp.withBroadcastSet(sortedLabels, CommonUtils.SORTED_LABELS_BC_NAME);
        return new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                getMLEnvironmentId(),
                (DataSet<Row>) modelOp,
                new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(getMLEnvironmentId());
    }

    /** Emits each distinct value of field 0 seen in a partition (the partition-local distinct). */
    private static final class DistinctLabelsMapPartition implements MapPartitionFunction<Row, Object> {
        private static final long serialVersionUID = 1L;

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Object> out) {
            HashSet<Object> distinct = new HashSet<>();
            for (Row row : rows) {
                distinct.add(row.getField(0));
            }
            for (Object label : distinct) {
                out.collect(label);
            }
        }
    }

    /**
     * Merges the partition-local distinct labels into a single naturally ordered list.
     * Labels must be mutually {@code Comparable}, because a {@link TreeSet} is used.
     */
    private static final class SortLabelsGroupReduce implements GroupReduceFunction<Object, List<Object>> {
        private static final long serialVersionUID = 1L;

        @Override
        public void reduce(Iterable<Object> labels, Collector<List<Object>> out) {
            TreeSet<Object> sorted = new TreeSet<>();
            for (Object label : labels) {
                sorted.add(label);
            }
            out.collect(new ArrayList<>(sorted));
        }
    }
}
