package com.alibaba.alink.pipeline;

import com.alibaba.alink.common.LocalMLEnvironment;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.lazy.HasLazyPrintModelInfo;
import com.alibaba.alink.common.lazy.HasLazyPrintTrainInfo;
import com.alibaba.alink.common.lazy.HasLazyPrintTransformInfo;
import com.alibaba.alink.common.lazy.WithTrainInfoLocalOp;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.batch.utils.WithTrainInfo;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.lazy.WithModelInfoLocalOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.params.ModelStreamScanParams;
import com.alibaba.alink.pipeline.MapModel;
import com.alibaba.alink.pipeline.TrainerLegacy;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/pipeline/TrainerLegacy.class */
public abstract class TrainerLegacy<T extends TrainerLegacy<T, M>, M extends MapModel<M>> extends EstimatorBase<T, M> implements ModelStreamScanParams<T>, HasLazyPrintTransformInfo<T> {
    public TrainerLegacy() {
    }

    public TrainerLegacy(Params params) {
        super(params);
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public M fit(BatchOperator<?> batchOperator) {
        M createModel = createModel(postProcessTrainOp(train(batchOperator)));
        checkModelValidity((TrainerLegacy<T, M>) createModel, batchOperator);
        return postProcessModel(createModel);
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public M fit(LocalOperator<?> localOperator) {
        M createModel = createModel(postProcessTrainOp(train(localOperator)));
        checkModelValidity((TrainerLegacy<T, M>) createModel, localOperator);
        return postProcessModel(createModel);
    }

    /* JADX WARN: Multi-variable type inference failed */
    protected BatchOperator<?> postProcessTrainOp(BatchOperator<?> batchOperator) {
        MLEnvironmentFactory.get(batchOperator.getMLEnvironmentId()).getLazyObjectsManager().genLazyTrainOp((TrainerLegacy<?, ?>) this).addValue(batchOperator);
        if ((this instanceof HasLazyPrintTrainInfo) && ((Boolean) get(HasLazyPrintTrainInfo.LAZY_PRINT_TRAIN_INFO_ENABLED)).booleanValue()) {
            ((WithTrainInfo) batchOperator).lazyPrintTrainInfo((String) get(HasLazyPrintTrainInfo.LAZY_PRINT_TRAIN_INFO_TITLE));
        }
        if ((this instanceof HasLazyPrintModelInfo) && ((Boolean) get(HasLazyPrintModelInfo.LAZY_PRINT_MODEL_INFO_ENABLED)).booleanValue()) {
            ((WithModelInfoBatchOp) batchOperator).lazyPrintModelInfo((String) get(HasLazyPrintModelInfo.LAZY_PRINT_MODEL_INFO_TITLE));
        }
        return batchOperator;
    }

    /* JADX WARN: Multi-variable type inference failed */
    protected LocalOperator<?> postProcessTrainOp(LocalOperator<?> localOperator) {
        LocalMLEnvironment.getInstance().getLazyObjectsManager().genLazyTrainOp((TrainerLegacy<?, ?>) this).addValue(localOperator);
        if ((this instanceof HasLazyPrintTrainInfo) && ((Boolean) get(HasLazyPrintTrainInfo.LAZY_PRINT_TRAIN_INFO_ENABLED)).booleanValue()) {
            ((WithTrainInfoLocalOp) localOperator).lazyPrintTrainInfo((String) get(HasLazyPrintTrainInfo.LAZY_PRINT_TRAIN_INFO_TITLE));
        }
        if ((this instanceof HasLazyPrintModelInfo) && ((Boolean) get(HasLazyPrintModelInfo.LAZY_PRINT_MODEL_INFO_ENABLED)).booleanValue()) {
            ((WithModelInfoLocalOp) localOperator).lazyPrintModelInfo((String) get(HasLazyPrintModelInfo.LAZY_PRINT_MODEL_INFO_TITLE));
        }
        return localOperator;
    }

    protected void checkModelValidity(M m, BatchOperator<?> batchOperator) {
        if (m instanceof MapModel) {
            m.validate(m.getModelData().getSchema(), batchOperator.getSchema());
        }
    }

    protected void checkModelValidity(M m, LocalOperator<?> localOperator) {
        if (m instanceof MapModel) {
            m.validate(m.getModelDataLocal().getSchema(), localOperator.getSchema());
        }
    }

    protected M postProcessModel(M m) {
        MLEnvironmentFactory.get(m.getMLEnvironmentId()).getLazyObjectsManager().genLazyModel((TrainerLegacy<?, ?>) this).addValue(m);
        if (this instanceof HasLazyPrintTransformInfo) {
            if (((Boolean) get(LAZY_PRINT_TRANSFORM_DATA_ENABLED)).booleanValue()) {
                m.enableLazyPrintTransformData(((Integer) get(LAZY_PRINT_TRANSFORM_DATA_NUM)).intValue(), (String) get(LAZY_PRINT_TRANSFORM_DATA_TITLE));
            }
            if (((Boolean) get(LAZY_PRINT_TRANSFORM_STAT_ENABLED)).booleanValue()) {
                m.enableLazyPrintTransformStat((String) get(LAZY_PRINT_TRANSFORM_STAT_TITLE));
            }
        }
        return m;
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public M fit(StreamOperator<?> streamOperator) {
        throw new AkUnsupportedOperationException("Only support batch or local fit!");
    }

    private M createModel(BatchOperator<?> batchOperator) {
        try {
            return (M) ((MapModel) ((Class) ((ParameterizedType) getClass().getGenericSuperclass()).getActualTypeArguments()[1]).getConstructor(Params.class).newInstance(getParams())).setModelData(batchOperator);
        } catch (ExceptionWithErrorCode e) {
            throw e;
        } catch (Exception e2) {
            throw new AkUnclassifiedErrorException("Error. ", e2);
        }
    }

    private M createModel(LocalOperator<?> localOperator) {
        try {
            return (M) ((MapModel) ((Class) ((ParameterizedType) getClass().getGenericSuperclass()).getActualTypeArguments()[1]).getConstructor(Params.class).newInstance(getParams())).setModelData(localOperator);
        } catch (ExceptionWithErrorCode e) {
            throw e;
        } catch (Exception e2) {
            throw new AkUnclassifiedErrorException("Error. ", e2);
        }
    }

    protected abstract BatchOperator<?> train(BatchOperator<?> batchOperator);

    protected StreamOperator<?> train(StreamOperator<?> streamOperator) {
        throw new AkUnsupportedOperationException("Only support batch fit!");
    }

    protected LocalOperator<?> train(LocalOperator<?> localOperator) {
        throw new AkUnsupportedOperationException("Not supported yet!");
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ ModelBase fit(LocalOperator localOperator) {
        return fit((LocalOperator<?>) localOperator);
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ ModelBase fit(StreamOperator streamOperator) {
        return fit((StreamOperator<?>) streamOperator);
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ ModelBase fit(BatchOperator batchOperator) {
        return fit((BatchOperator<?>) batchOperator);
    }
}
