package com.alibaba.alink.pipeline;

import com.alibaba.alink.common.exceptions.AkIllegalOperationException;
import com.alibaba.alink.common.io.filesystem.FilePath;
import com.alibaba.alink.common.lazy.HasLazyPrintTransformInfo;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.sink.AkSinkLocalOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.params.PipelineModelParams;
import com.alibaba.alink.pipeline.ModelExporterUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Locale;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/pipeline/Pipeline.class */
public final class Pipeline extends EstimatorBase<Pipeline, PipelineModel> {
    private static final long serialVersionUID = 1562871813230757217L;
    ArrayList<PipelineStageBase<?>> stages;

    public Pipeline() {
        this(new Params());
    }

    public Pipeline(Params params) {
        super(params);
        this.stages = new ArrayList<>();
    }

    public Pipeline(PipelineStageBase<?>... pipelineStageBaseArr) {
        super(null);
        this.stages = new ArrayList<>();
        if (null != pipelineStageBaseArr) {
            this.stages.addAll(Arrays.asList(pipelineStageBaseArr));
        }
    }

    @Override // com.alibaba.alink.pipeline.PipelineStageBase
    /* renamed from: clone */
    public Pipeline mo1482clone() throws CloneNotSupportedException {
        Pipeline pipeline = new Pipeline();
        Iterator<PipelineStageBase<?>> it = this.stages.iterator();
        while (it.hasNext()) {
            pipeline.add(it.next().mo1482clone());
        }
        return pipeline;
    }

    public Pipeline add(PipelineStageBase<?> pipelineStageBase) {
        this.stages.add(pipelineStageBase);
        return this;
    }

    @Deprecated
    public Pipeline add(int i, PipelineStageBase<?> pipelineStageBase) {
        this.stages.add(i, pipelineStageBase);
        return this;
    }

    @Deprecated
    public Pipeline remove(int i) {
        this.stages.remove(i);
        return this;
    }

    public PipelineStageBase<?> get(int i) {
        return this.stages.get(i);
    }

    public int size() {
        return this.stages.size();
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public PipelineModel fit(BatchOperator<?> batchOperator) {
        PipelineModel pipelineModel = (PipelineModel) new PipelineModel((TransformerBase<?>[]) fit(batchOperator, false).f0).setMLEnvironmentId(batchOperator.getMLEnvironmentId());
        pipelineModel.getParams().set((ParamInfo<ParamInfo<String>>) PipelineModelParams.TRAINING_DATA_SCHEMA, (ParamInfo<String>) TableUtil.schema2SchemaStr(batchOperator.getSchema()));
        return pipelineModel;
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public BatchOperator<?> fitAndTransform(BatchOperator<?> batchOperator) {
        return (BatchOperator) fit(batchOperator, true).f1;
    }

    private Tuple2<TransformerBase<?>[], BatchOperator<?>> fit(BatchOperator<?> batchOperator, boolean z) {
        Iterator<PipelineStageBase<?>> it = this.stages.iterator();
        while (it.hasNext()) {
            PipelineStageBase<?> next = it.next();
            if ((next instanceof Trainer) && EstimatorTrainerCatalog.lookupBatchTrainer(next.getClass().getName()) == null) {
                throw new AkIllegalOperationException("Pipeline can't fit BatchOperator, for the Estimator(" + next.getClass().getName() + ") not support.");
            }
        }
        TableSchema schema = batchOperator.getSchema();
        int indexOfLastEstimator = getIndexOfLastEstimator();
        TransformerBase[] transformerBaseArr = new TransformerBase[this.stages.size()];
        for (int i = 0; i < this.stages.size(); i++) {
            PipelineStageBase<?> pipelineStageBase = this.stages.get(i);
            if (i > indexOfLastEstimator) {
                transformerBaseArr[i] = (TransformerBase) pipelineStageBase;
            } else if (pipelineStageBase instanceof EstimatorBase) {
                transformerBaseArr[i] = ((EstimatorBase) pipelineStageBase).fit(batchOperator);
            } else if (pipelineStageBase instanceof TransformerBase) {
                transformerBaseArr[i] = (TransformerBase) pipelineStageBase;
            }
            if (i < indexOfLastEstimator || z) {
                Boolean bool = (Boolean) transformerBaseArr[i].get(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_DATA_ENABLED);
                Boolean bool2 = (Boolean) transformerBaseArr[i].get(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_STAT_ENABLED);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_DATA_ENABLED, false);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_STAT_ENABLED, false);
                batchOperator = transformerBaseArr[i].transform(batchOperator);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_DATA_ENABLED, bool);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_STAT_ENABLED, bool2);
            }
        }
        PipelineModel.getOutSchema(new PipelineModel((TransformerBase<?>[]) transformerBaseArr), schema);
        return new Tuple2<>(transformerBaseArr, batchOperator);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public PipelineModel fit(LocalOperator<?> localOperator) {
        PipelineModel pipelineModel = new PipelineModel((TransformerBase<?>[]) fit(localOperator, false).f0);
        pipelineModel.getParams().set((ParamInfo<ParamInfo<String>>) PipelineModelParams.TRAINING_DATA_SCHEMA, (ParamInfo<String>) TableUtil.schema2SchemaStr(localOperator.getSchema()));
        return pipelineModel;
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public LocalOperator<?> fitAndTransform(LocalOperator<?> localOperator) {
        return (LocalOperator) fit(localOperator, true).f1;
    }

    private Tuple2<TransformerBase<?>[], LocalOperator<?>> fit(LocalOperator<?> localOperator, boolean z) {
        Iterator<PipelineStageBase<?>> it = this.stages.iterator();
        while (it.hasNext()) {
            PipelineStageBase<?> next = it.next();
            if ((next instanceof Trainer) && EstimatorTrainerCatalog.lookupLocalTrainer(next.getClass().getName()) == null) {
                throw new AkIllegalOperationException("Pipeline can't fit LocalOperator, for the Estimator(" + next.getClass().getName() + ") not support.");
            }
        }
        int indexOfLastEstimator = getIndexOfLastEstimator();
        TransformerBase[] transformerBaseArr = new TransformerBase[this.stages.size()];
        for (int i = 0; i < this.stages.size(); i++) {
            PipelineStageBase<?> pipelineStageBase = this.stages.get(i);
            if (i > indexOfLastEstimator) {
                transformerBaseArr[i] = (TransformerBase) pipelineStageBase;
            } else if (pipelineStageBase instanceof EstimatorBase) {
                transformerBaseArr[i] = ((EstimatorBase) pipelineStageBase).fit(localOperator);
            } else if (pipelineStageBase instanceof TransformerBase) {
                transformerBaseArr[i] = (TransformerBase) pipelineStageBase;
            }
            if (i < indexOfLastEstimator || z) {
                Boolean bool = (Boolean) transformerBaseArr[i].get(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_DATA_ENABLED);
                Boolean bool2 = (Boolean) transformerBaseArr[i].get(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_STAT_ENABLED);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_DATA_ENABLED, false);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_STAT_ENABLED, false);
                localOperator = transformerBaseArr[i].transform(localOperator);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_DATA_ENABLED, bool);
                transformerBaseArr[i].set(HasLazyPrintTransformInfo.LAZY_PRINT_TRANSFORM_STAT_ENABLED, bool2);
            }
        }
        return new Tuple2<>(transformerBaseArr, localOperator);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public PipelineModel fit(StreamOperator<?> streamOperator) {
        int indexOfLastEstimator = getIndexOfLastEstimator();
        TransformerBase[] transformerBaseArr = new TransformerBase[this.stages.size()];
        for (int i = 0; i < this.stages.size(); i++) {
            PipelineStageBase<?> pipelineStageBase = this.stages.get(i);
            if (i <= indexOfLastEstimator) {
                if (pipelineStageBase instanceof EstimatorBase) {
                    transformerBaseArr[i] = ((EstimatorBase) pipelineStageBase).fit(streamOperator);
                } else if (pipelineStageBase instanceof TransformerBase) {
                    transformerBaseArr[i] = (TransformerBase) pipelineStageBase;
                }
                if (i < indexOfLastEstimator) {
                    streamOperator = transformerBaseArr[i].transform(streamOperator);
                }
            } else {
                transformerBaseArr[i] = (TransformerBase) pipelineStageBase;
            }
        }
        return (PipelineModel) new PipelineModel((TransformerBase<?>[]) transformerBaseArr).setMLEnvironmentId(streamOperator.getMLEnvironmentId());
    }

    private int getIndexOfLastEstimator() {
        int i = -1;
        for (int i2 = 0; i2 < this.stages.size(); i2++) {
            if (this.stages.get(i2) instanceof EstimatorBase) {
                i = i2;
            }
        }
        return i;
    }

    public void save(String str) {
        save(str, false);
    }

    public void save(String str, boolean z) {
        save(new FilePath(str), z);
    }

    public void save(FilePath filePath) {
        save(filePath, false);
    }

    public void save(FilePath filePath, boolean z) {
        save(filePath, z, 1);
    }

    public void save(FilePath filePath, boolean z, int i) {
        save(filePath, z, i, "auto");
    }

    public void save(FilePath filePath, boolean z, int i, String str) {
        String lowerCase = str.toLowerCase();
        if (lowerCase.equals("batch")) {
            saveBatch(filePath, z, i);
            return;
        }
        if (lowerCase.equals("local")) {
            saveLocal(filePath, z, i);
            return;
        }
        if (!lowerCase.equals("auto")) {
            throw new AkIllegalOperationException("Not support this save mode : " + lowerCase);
        }
        Tuple2<Boolean, Boolean> checkModels = PipelineModel.checkModels((PipelineStageBase[]) this.stages.toArray(new PipelineStageBase[0]));
        boolean booleanValue = ((Boolean) checkModels.f0).booleanValue();
        boolean booleanValue2 = ((Boolean) checkModels.f1).booleanValue();
        if (booleanValue && booleanValue2) {
            saveBatch(filePath, z, i);
        } else {
            saveLocal(filePath, z, i);
            ModelExporterUtils.createEmptyBatchSourceSink(getMLEnvironmentId());
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    private void saveBatch(FilePath filePath, boolean z, int i) {
        saveBatch().link(((AkSinkBatchOp) new AkSinkBatchOp().setMLEnvironmentId(getMLEnvironmentId())).setFilePath(filePath).setOverwriteSink(Boolean.valueOf(z)).setNumFiles(Integer.valueOf(i)));
    }

    private void saveLocal(FilePath filePath, boolean z, int i) {
        saveLocal().link(new AkSinkLocalOp().setFilePath(filePath).setOverwriteSink(Boolean.valueOf(z)).setNumFiles(Integer.valueOf(i)));
    }

    @Deprecated
    public BatchOperator<?> save() {
        return saveBatch();
    }

    private BatchOperator<?> saveBatch() {
        return ModelExporterUtils.serializePipelineStages(this.stages, this.params);
    }

    public LocalOperator<?> saveLocal() {
        return ModelExporterUtils.serializePipelineStagesLocal(this.stages, this.params);
    }

    public static Pipeline loadLocal(LocalOperator<?> localOperator) {
        return new Pipeline((PipelineStageBase<?>[]) ModelExporterUtils.fillPipelineStages(localOperator, (ModelExporterUtils.StageNode[]) ModelExporterUtils.collectMetaFromOp(localOperator).f0, localOperator.getSchema()).toArray(new PipelineStageBase[0]));
    }

    public static Pipeline load(String str) {
        return load(new FilePath(str));
    }

    public static Pipeline load(FilePath filePath) {
        Tuple2<TableSchema, Row> loadMetaFromAkFile = ModelExporterUtils.loadMetaFromAkFile(filePath);
        return new Pipeline((PipelineStageBase<?>[]) ModelExporterUtils.fillPipelineStages(new ModelPipeFileData(filePath), ModelExporterUtils.deserializePipelineStagesFromMeta((Row) loadMetaFromAkFile.f1, (TableSchema) loadMetaFromAkFile.f0), (TableSchema) loadMetaFromAkFile.f0).toArray(new PipelineStageBase[0]));
    }

    public static Pipeline collectLoad(BatchOperator<?> batchOperator) {
        return new Pipeline((PipelineStageBase<?>[]) ModelExporterUtils.fillPipelineStages(batchOperator, (ModelExporterUtils.StageNode[]) ModelExporterUtils.collectMetaFromOp(batchOperator).f0, batchOperator.getSchema()).toArray(new PipelineStageBase[0]));
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Deprecated
    public static Pipeline load(FilePath filePath, Long l) {
        Tuple2<TableSchema, Row> loadMetaFromAkFile = ModelExporterUtils.loadMetaFromAkFile(filePath);
        return new Pipeline((PipelineStageBase<?>[]) ModelExporterUtils.fillPipelineStages((BatchOperator<?>) new AkSourceBatchOp().setFilePath(filePath).setMLEnvironmentId(l), ModelExporterUtils.deserializePipelineStagesFromMeta((Row) loadMetaFromAkFile.f1, (TableSchema) loadMetaFromAkFile.f0), (TableSchema) loadMetaFromAkFile.f0).toArray(new PipelineStageBase[0]));
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ PipelineModel fit(LocalOperator localOperator) {
        return fit((LocalOperator<?>) localOperator);
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ PipelineModel fit(StreamOperator streamOperator) {
        return fit((StreamOperator<?>) streamOperator);
    }

    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ PipelineModel fit(BatchOperator batchOperator) {
        return fit((BatchOperator<?>) batchOperator);
    }
}
