package com.alibaba.alink.pipeline.tensorflow;

import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.tensorflow.TF2TableModelTrainBatchOp;
import com.alibaba.alink.params.tensorflow.TF2TableModelTrainParams;
import com.alibaba.alink.params.tensorflow.savedmodel.HasInferSelectedColsDefaultAsNull;
import com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams;
import com.alibaba.alink.pipeline.MapModel;
import com.alibaba.alink.pipeline.ModelBase;
import com.alibaba.alink.pipeline.TrainerLegacy;
import org.apache.flink.ml.api.misc.param.Params;

/**
 * Pipeline trainer ("TF2 table model") that trains a model via {@link TF2TableModelTrainBatchOp}
 * and produces a {@link TFTableModelPredictor}.
 *
 * <p>This stage carries both the training parameters ({@link TF2TableModelTrainParams}) and the
 * prediction parameters ({@link TFTableModelPredictParams}). This lets the predictor returned by
 * {@link #fit(BatchOperator)} be configured from the same stage. It also carries the
 * {@code inferSelectedCols} parameter from {@link HasInferSelectedColsDefaultAsNull}, which is
 * forwarded to the predictor as its selected columns.
 */
@NameCn("TF2表模型")
public class TF2TableModelTrainer
    extends TrainerLegacy<TF2TableModelTrainer, TFTableModelPredictor>
    implements TF2TableModelTrainParams<TF2TableModelTrainer>,
        TFTableModelPredictParams<TF2TableModelTrainer>,
        HasInferSelectedColsDefaultAsNull<TF2TableModelTrainer> {

    /** Creates a trainer by delegating to {@link #TF2TableModelTrainer(Params)} with {@code null}. */
    public TF2TableModelTrainer() {
        this(null);
    }

    /**
     * Creates a trainer with the given parameters.
     *
     * @param params initial parameters; passed unchanged to the superclass constructor
     *     (handling of {@code null} is up to {@code TrainerLegacy} — not visible here)
     */
    public TF2TableModelTrainer(Params params) {
        super(params);
    }

    /**
     * Trains on {@code batchOperator} and returns the resulting predictor. The predictor's
     * selected columns are set to this stage's {@code inferSelectedCols} value.
     *
     * <p>Only this generic signature is declared. The covariant-return bridge methods returning
     * {@code MapModel}/{@code ModelBase} are generated by the compiler. Declaring them by hand
     * would be a duplicate-erasure compile error.
     *
     * @param batchOperator training data
     * @return the fitted predictor
     */
    @Override
    public TFTableModelPredictor fit(BatchOperator<?> batchOperator) {
        TFTableModelPredictor predictor = (TFTableModelPredictor) super.fit(batchOperator);
        // Propagate the inference input columns so the predictor reads the intended columns.
        predictor.setSelectedCols(getInferSelectedCols());
        return predictor;
    }

    /**
     * Builds the training batch operator from this stage's parameters and links it to the input.
     *
     * @param batchOperator training data
     * @return the linked {@link TF2TableModelTrainBatchOp}
     */
    @Override
    protected BatchOperator<?> train(BatchOperator<?> batchOperator) {
        return new TF2TableModelTrainBatchOp(getParams()).linkFrom(batchOperator);
    }
}
