package com.alibaba.alink.pipeline;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.utils.ModelMapLocalOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.utils.ModelMapStreamOp;
import com.alibaba.alink.params.ModelStreamScanParams;
import com.alibaba.alink.pipeline.MapModel;
import java.util.List;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.function.TriFunction;

/* loaded from: input_file:com/alibaba/alink/pipeline/MapModel.class */
/**
 * Base class for pipeline models whose prediction logic is supplied by a {@link ModelMapper}.
 *
 * <p>The mapper is not held directly. Instead the model keeps a builder function that is
 * invoked with two schemas and the model's {@link Params} and yields a {@link ModelMapper}.
 * The same builder is handed to the batch, stream and local map operators and is also used
 * to assemble a {@link LocalPredictor}.
 *
 * @param <T> the concrete model type (self-referencing generic)
 */
public abstract class MapModel<T extends MapModel<T>> extends ModelBase<T> implements ModelStreamScanParams<T>, LocalPredictable {
    private static final long serialVersionUID = 8333228095437207694L;

    /**
     * Builder producing a {@link ModelMapper} from two schemas and the params. In
     * {@link #collectLocalPredictor(TableSchema)} the first schema is the model data schema
     * and the second is the schema passed in by the caller.
     */
    final TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder;

    /**
     * Creates the model around a mapper builder.
     *
     * <p>Decompiler note: JADX reports that this constructor's access modifier was changed
     * from {@code protected}.
     *
     * @param triFunction the mapper builder; checked with
     *                    {@code AkPreconditions.checkNotNull}
     * @param params      parameters forwarded to the super constructor
     */
    public MapModel(TriFunction<TableSchema, TableSchema, Params, ModelMapper> triFunction, Params params) {
        super(params);
        final TriFunction<TableSchema, TableSchema, Params, ModelMapper> checkedBuilder =
            (TriFunction<TableSchema, TableSchema, Params, ModelMapper>)
                AkPreconditions.checkNotNull(triFunction, "mapperBuilder can not be null");
        this.mapperBuilder = checkedBuilder;
    }

    /**
     * Invokes the mapper builder with the given schemas and this model's params and discards
     * the resulting mapper; anything the builder throws propagates to the caller.
     *
     * <p>Decompiler note: JADX reports that this method's access modifier was changed from
     * package-private.
     *
     * @param tableSchema  passed as the builder's first schema argument (the position where
     *                     {@link #collectLocalPredictor(TableSchema)} supplies the model
     *                     data schema)
     * @param tableSchema2 passed as the builder's second schema argument
     */
    public void validate(TableSchema tableSchema, TableSchema tableSchema2) {
        // Only the act of building matters here; the mapper itself is not kept.
        this.mapperBuilder.apply(tableSchema, tableSchema2, this.params);
    }

    /**
     * Links the model data and the given batch input into a {@link ModelMapBatchOp} and
     * passes the linked operator through {@code postProcessTransformResult}.
     *
     * @param batchOperator the batch input to transform
     * @return the post-processed result operator
     */
    @Override // com.alibaba.alink.pipeline.TransformerBase
    public BatchOperator<?> transform(BatchOperator<?> batchOperator) {
        final ModelMapBatchOp mapOp = new ModelMapBatchOp(this.mapperBuilder, this.params);
        return postProcessTransformResult(mapOp.linkFrom(getModelData(), batchOperator));
    }

    /**
     * Builds a {@link ModelMapStreamOp} from the model data, the mapper builder and the
     * params, and links the given stream input to it.
     *
     * <p>NOTE(review): unlike the batch and local variants, the linked operator is returned
     * directly without going through {@code postProcessTransformResult} - confirm this is
     * intentional.
     *
     * @param streamOperator the stream input to transform
     * @return the linked stream operator
     */
    @Override // com.alibaba.alink.pipeline.TransformerBase
    public StreamOperator<?> transform(StreamOperator<?> streamOperator) {
        final ModelMapStreamOp mapOp = new ModelMapStreamOp(getModelData(), this.mapperBuilder, this.params);
        return mapOp.linkFrom(streamOperator);
    }

    /**
     * Links the local model data and the given local input into a {@link ModelMapLocalOp}
     * and passes the linked operator through {@code postProcessTransformResult}.
     *
     * @param localOperator the local input to transform
     * @return the post-processed result operator
     */
    @Override // com.alibaba.alink.pipeline.TransformerBase
    public LocalOperator<?> transform(LocalOperator<?> localOperator) {
        final ModelMapLocalOp mapOp = new ModelMapLocalOp(this.mapperBuilder, this.params);
        return postProcessTransformResult(mapOp.linkFrom(getModelDataLocal(), localOperator));
    }

    /**
     * Creates a {@link LocalPredictor} backed by a mapper that has been loaded with the
     * collected model rows.
     *
     * @param tableSchema the schema handed to the mapper builder as its second argument
     * @return a predictor wrapping the loaded mapper
     * @throws Exception declared by the {@link LocalPredictable} contract; anything raised
     *                   while collecting rows, building or loading the mapper propagates
     */
    @Override // com.alibaba.alink.pipeline.LocalPredictable
    public LocalPredictor collectLocalPredictor(TableSchema tableSchema) throws Exception {
        // Rows of the model data; handed to the mapper through loadModel below.
        final List<Row> modelRows = getModelData().collect();
        final TableSchema modelSchema = getModelData().getSchema();
        final ModelMapper mapper = this.mapperBuilder.apply(modelSchema, tableSchema, getParams());
        mapper.loadModel(modelRows);
        return new LocalPredictor(mapper);
    }
}
