package com.alibaba.alink.pipeline.tuning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.lazy.HasLazyPrintTrainInfo;
import com.alibaba.alink.common.lazy.LazyEvaluation;
import com.alibaba.alink.common.lazy.LazyObjectsManager;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.SplitBatchOp;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.source.MemSourceLocalOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.params.tuning.ParallelTuningMode;
import com.alibaba.alink.pipeline.EstimatorBase;
import com.alibaba.alink.pipeline.ModelBase;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.TransformerBase;
import com.alibaba.alink.pipeline.tuning.BaseTuning;
import com.alibaba.alink.pipeline.tuning.BaseTuningModel;
import com.alibaba.alink.pipeline.tuning.Report;
import java.lang.reflect.ParameterizedType;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@NameCn("")
/* loaded from: input_file:com/alibaba/alink/pipeline/tuning/BaseTuning.class */
public abstract class BaseTuning<T extends BaseTuning<T, M>, M extends BaseTuningModel<M>> extends EstimatorBase<T, M> implements HasLazyPrintTrainInfo<T>, ParallelTuningMode<T> {
    // SLF4J logger bound to this class.
    private static final Logger LOG = LoggerFactory.getLogger(BaseTuning.class);
    // Java serialization version identifier.
    private static final long serialVersionUID = 7100530176503587968L;
    // Estimator exposed through getEstimator/setEstimator; presumably the estimator being tuned — TODO confirm in subclasses.
    private EstimatorBase<?, ?> estimator;
    // Evaluator used by the findBest* methods to score fitted candidates and to decide whether larger metrics are better.
    private TuningEvaluator<?> tuningEvaluator;

    /**
     * Returns the estimator held by this tuner.
     *
     * @return the current value of the estimator field
     */
    public EstimatorBase<?, ?> getEstimator() {
        return estimator;
    }

    /**
     * Stores the estimator on this tuner.
     *
     * @param estimatorBase the estimator to store
     * @return this instance, typed as the concrete tuner type for call chaining
     */
    @SuppressWarnings("unchecked")
    public T setEstimator(EstimatorBase<?, ?> estimatorBase) {
        this.estimator = estimatorBase;
        // "this" is a BaseTuning<T, M>, which is not assignable to T without a cast; the
        // self-bounded type parameter (T extends BaseTuning<T, M>) is what makes it safe.
        return (T) this;
    }

    /**
     * Returns the evaluator held by this tuner.
     *
     * @return the current value of the tuningEvaluator field
     */
    public TuningEvaluator<?> getTuningEvaluator() {
        return tuningEvaluator;
    }

    /**
     * Stores the evaluator used to score candidate pipelines.
     *
     * @param tuningEvaluator the evaluator to store
     * @return this instance, typed as the concrete tuner type for call chaining
     */
    @SuppressWarnings("unchecked")
    public T setTuningEvaluator(TuningEvaluator<?> tuningEvaluator) {
        this.tuningEvaluator = tuningEvaluator;
        // "this" is a BaseTuning<T, M>, which is not assignable to T without a cast; the
        // self-bounded type parameter (T extends BaseTuning<T, M>) is what makes it safe.
        return (T) this;
    }

    /**
     * Tunes on batch data and wraps the selected transformer in the tuning model.
     *
     * <p>When the lazy-print-train-info flag is enabled, the tuning report is registered as a
     * lazy value whose callback prints the optional title followed by the report.
     *
     * @param batchOperator the batch data handed to {@code tuning}
     * @return the model created from the transformer returned by {@code tuning}
     */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public M fit(BatchOperator<?> batchOperator) {
        final Tuple2<TransformerBase, Report> bestAndReport = tuning(batchOperator);
        final boolean printTrainInfo = ((Boolean) getParams().get(LAZY_PRINT_TRAIN_INFO_ENABLED)).booleanValue();
        if (!printTrainInfo) {
            return createModel(bestAndReport.f0);
        }
        final String title = (String) getParams().get(LAZY_PRINT_TRAIN_INFO_TITLE);
        final LazyEvaluation<Report> lazyReport = LazyObjectsManager.getLazyObjectsManager(this).genLazyReport(this);
        lazyReport.addCallback(report -> {
            if (title != null) {
                System.out.println(title);
            }
            System.out.println(report.toString());
        });
        lazyReport.addValue(bestAndReport.f1);
        return createModel(bestAndReport.f0);
    }

    /**
     * Tunes on local (in-memory) data and wraps the selected transformer in the tuning model.
     *
     * <p>When the lazy-print-train-info flag is enabled, the tuning report is registered as a
     * lazy value whose callback prints the optional title followed by the report.
     *
     * @param localOperator the local data handed to {@code tuning}
     * @return the model created from the transformer returned by {@code tuning}
     */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public M fit(LocalOperator<?> localOperator) {
        final Tuple2<TransformerBase, Report> bestAndReport = tuning(localOperator);
        final boolean printTrainInfo = ((Boolean) getParams().get(LAZY_PRINT_TRAIN_INFO_ENABLED)).booleanValue();
        if (!printTrainInfo) {
            return createModel(bestAndReport.f0);
        }
        final String title = (String) getParams().get(LAZY_PRINT_TRAIN_INFO_TITLE);
        final LazyEvaluation<Report> lazyReport = LazyObjectsManager.getLazyObjectsManager(this).genLazyReport(this);
        lazyReport.addCallback(report -> {
            if (title != null) {
                System.out.println(title);
            }
            System.out.println(report.toString());
        });
        lazyReport.addValue(bestAndReport.f1);
        return createModel(bestAndReport.f0);
    }

    /**
     * Stream tuning is not available; this overload always fails.
     *
     * @param streamOperator ignored
     * @return never returns normally
     * @throws AkUnsupportedOperationException always
     */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public M fit(StreamOperator<?> streamOperator) {
        final String reason = "Tuning on stream not supported.";
        throw new AkUnsupportedOperationException(reason);
    }

    /**
     * Reflectively instantiates the model type {@code M} around the given transformer.
     *
     * <p>The model class is read from the second type argument of this object's generic
     * superclass and must expose a public constructor taking a single {@link TransformerBase}.
     *
     * @param transformerBase the transformer passed to the model constructor
     * @return the newly constructed model
     * @throws RuntimeException wrapping any failure raised while resolving or invoking the constructor
     */
    @SuppressWarnings("unchecked")
    private M createModel(TransformerBase transformerBase) {
        try {
            final ParameterizedType genericSuper = (ParameterizedType) getClass().getGenericSuperclass();
            final Class<?> modelClass = (Class<?>) genericSuper.getActualTypeArguments()[1];
            final Object model = modelClass.getConstructor(TransformerBase.class).newInstance(transformerBase);
            return (M) model;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs the tuning procedure on batch data.
     *
     * @param batchOperator the batch data to tune on
     * @return a pair whose first field is the transformer that {@code fit} wraps into the model
     *         and whose second field is the tuning report
     */
    protected abstract Tuple2<TransformerBase, Report> tuning(BatchOperator<?> batchOperator);

    /**
     * Runs the tuning procedure on local (in-memory) data.
     *
     * @param localOperator the local data to tune on
     * @return a pair whose first field is the transformer that {@code fit} wraps into the model
     *         and whose second field is the tuning report
     */
    protected abstract Tuple2<TransformerBase, Report> tuning(LocalOperator<?> localOperator);

    /**
     * Picks the best candidate pipeline with a single random train/validation split of local data.
     *
     * <p>Each input row goes to the training set with probability {@code d} and to the validation
     * set otherwise. If either side ends up empty, the other side is used for both roles. Every
     * candidate is fitted on the training set and scored on the validation set; candidates that
     * throw or score NaN are recorded in the report with a reason.
     *
     * <p>Candidates with a NaN metric are never selected as best. Every ordered comparison against
     * NaN is false, so a NaN first candidate would otherwise stay "best" forever.
     *
     * @param localOperator          the data to split
     * @param d                      probability that a row is assigned to the training set
     * @param pipelineCandidatesBase the candidate pipelines
     * @return the best pipeline (freshly obtained from the candidates) and the full report
     * @throws RuntimeException if a candidate cannot be cloned, or if no candidate produced a
     *                          non-NaN metric
     */
    public Tuple2<Pipeline, Report> findBestTVSplit(LocalOperator<?> localOperator, double d, PipelineCandidatesBase pipelineCandidatesBase) {
        final int candidateCount = pipelineCandidatesBase.size();
        final List<Report.ReportElement> reportElements = new ArrayList<>();
        final List<Double> experienceScores = new ArrayList<>(candidateCount);

        // Row-level random split.
        final List<Row> trainRows = new ArrayList<>();
        final List<Row> validationRows = new ArrayList<>();
        for (Row row : localOperator.getOutputTable().getRows()) {
            if (ThreadLocalRandom.current().nextDouble() <= d) {
                trainRows.add(row);
            } else {
                validationRows.add(row);
            }
        }
        LocalOperator<?> trainData = new MemSourceLocalOp(trainRows, localOperator.getSchema());
        LocalOperator<?> validationData = new MemSourceLocalOp(validationRows, localOperator.getSchema());
        // An empty side borrows the other side so fitting/evaluation still has data.
        if (trainData.getOutputTable().getNumRow() == 0) {
            trainData = validationData;
        }
        if (validationData.getOutputTable().getNumRow() == 0) {
            validationData = trainData;
        }

        for (int i = 0; i < candidateCount; i++) {
            final Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate;
            try {
                candidate = pipelineCandidatesBase.get(i, experienceScores);
            } catch (CloneNotSupportedException e) {
                throw new RuntimeException(e);
            }
            double validationMetric = Double.NaN;
            String failure = null;
            try {
                final Pipeline pipeline = candidate.f0;
                final PipelineModel model = pipeline.fit(trainData);
                validationMetric = this.tuningEvaluator.evaluate(model.transform(validationData));
                final double trainMetric = this.tuningEvaluator.evaluate(model.transform(trainData));
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("taskId:%d params:%s trainMetric:%f testMetric:%f", i, pipeline.get(pipeline.size() - 1).getParams().toString(), trainMetric, validationMetric));
                }
            } catch (Exception e) {
                failure = ExceptionUtils.stringifyException(e);
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("BestTVSplit, i: %d, metric: %f, exception: %s", i, validationMetric, failure));
                }
            }
            // Exactly one score per candidate, at the candidate's own index.
            experienceScores.add(i, validationMetric);
            if (failure != null) {
                reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, validationMetric, failure));
            } else if (Double.isNaN(validationMetric)) {
                reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, validationMetric, "Metric is nan."));
            } else {
                reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, validationMetric));
            }
        }

        int bestIndex = -1;
        double bestMetric = 0.0d;
        for (int i = 0; i < reportElements.size(); i++) {
            final Report.ReportElement element = reportElements.get(i);
            final double metric = element.getMetric().doubleValue();
            // NaN marks a failed candidate; it must not take part in the selection.
            if (!Double.isNaN(metric)) {
                final boolean improves = bestIndex == -1
                    || (this.tuningEvaluator.isLargerBetter() ? bestMetric < metric : bestMetric > metric);
                if (improves) {
                    bestMetric = metric;
                    bestIndex = i;
                }
            }
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("BestTVSplit, i: %d, best: %f, metric: %f", i, bestMetric, element.getMetric()));
            }
        }
        if (bestIndex < 0) {
            throw new RuntimeException("Can not find a best model. Report: " + new Report(this.tuningEvaluator, reportElements).toPrettyJson());
        }
        try {
            return Tuple2.of(pipelineCandidatesBase.get(bestIndex, experienceScores).f0, new Report(this.tuningEvaluator, reportElements));
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Picks the best candidate pipeline with a train/validation split of batch data, choosing
     * parallel or sequential evaluation according to this tuner's parallel-tuning-mode setting.
     *
     * @param batchOperator          the data to split
     * @param d                      the training fraction forwarded to the four-argument overload
     * @param pipelineCandidatesBase the candidate pipelines
     * @return the best pipeline and the tuning report
     */
    public Tuple2<Pipeline, Report> findBestTVSplit(BatchOperator<?> batchOperator, double d, PipelineCandidatesBase pipelineCandidatesBase) {
        final boolean parallel = getParallelTuningMode().booleanValue();
        return findBestTVSplit(batchOperator, d, pipelineCandidatesBase, parallel);
    }

    /**
     * Picks the best candidate pipeline with a train/validation split of batch data.
     *
     * <p>Parallel mode ({@code z == true}): every candidate is saved to a local table, every input
     * row is tagged train (with probability {@code d}) or validation and sent to every subtask,
     * and each subtask fits and evaluates its own contiguous range of candidates locally. A failure
     * of that job is rethrown with its cause.
     *
     * <p>Sequential mode ({@code z == false}): the shuffled input is split once with
     * {@link SplitBatchOp}, and each candidate is fitted on the main output and scored on side
     * output 0. Candidates that throw or score NaN are recorded in the report with a reason.
     *
     * <p>In both modes, candidates with a NaN metric are never selected as best. Every ordered
     * comparison against NaN is false, so a NaN first candidate would otherwise stay "best" forever.
     *
     * @param batchOperator          the data to split
     * @param d                      training fraction / probability of a row being a training row
     * @param pipelineCandidatesBase the candidate pipelines
     * @param z                      {@code true} to evaluate candidates in parallel inside one job
     * @return the best pipeline (freshly obtained from the candidates) and the full report
     * @throws RuntimeException      if a candidate cannot be cloned, the parallel job fails, or no
     *                               candidate produced a non-NaN metric
     * @throws IllegalStateException if the parallel job returned no result for some candidate
     */
    protected Tuple2<Pipeline, Report> findBestTVSplit(BatchOperator<?> batchOperator, final double d, PipelineCandidatesBase pipelineCandidatesBase, boolean z) {
        final int size = pipelineCandidatesBase.size();
        final List<Report.ReportElement> reportElements = new ArrayList<>();
        final List<Double> experienceScores = new ArrayList<>(size);
        // The evaluator is rebuilt inside each subtask from its class and params.
        final Params params = this.tuningEvaluator.getParams();
        final Class<?> cls = this.tuningEvaluator.getClass();
        if (z) {
            final ArrayList<MTable> pipelineTables = new ArrayList<>();
            final List<Double> noScoresYet = new ArrayList<>();
            final List<List<Tuple3<Integer, ParamInfo, Object>>> candidateParams = new ArrayList<>();
            for (int i = 0; i < size; i++) {
                // Placeholder; replaced once the job reports this candidate's metric.
                reportElements.add(null);
                try {
                    // One get() per candidate, so the saved pipeline and its parameter items match.
                    final Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate = pipelineCandidatesBase.get(i, noScoresYet);
                    pipelineTables.add(candidate.f0.saveLocal().getOutputTable());
                    candidateParams.add(candidate.f1);
                } catch (CloneNotSupportedException e) {
                    throw new RuntimeException(e);
                }
            }
            final TableSchema schema = batchOperator.getSchema();
            final String[] fieldNames = schema.getFieldNames();
            final DataType[] fieldDataTypes = schema.getFieldDataTypes();
            final List<Row> results;
            try {
                results = batchOperator.rebalance().getDataSet().mapPartition(new RichMapPartitionFunction<Row, Tuple3<Integer, Integer, Row>>() {
                    @Override
                    public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Integer, Integer, Row>> collector) throws Exception {
                        final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
                        for (Row row : iterable) {
                            // Tag 0 = training row, 1 = validation row; the same tag goes to every subtask.
                            final int splitTag = ThreadLocalRandom.current().nextDouble() <= d ? 0 : 1;
                            for (int target = 0; target < parallelism; target++) {
                                collector.collect(Tuple3.of(target, splitTag, row));
                            }
                        }
                    }
                }).partitionCustom(new Partitioner<Integer>() {
                    @Override
                    public int partition(Integer num, int numPartitions) {
                        return num.intValue() % numPartitions;
                    }
                }, 0).mapPartition(new RichMapPartitionFunction<Tuple3<Integer, Integer, Row>, Row>() {
                    @Override
                    public void mapPartition(Iterable<Tuple3<Integer, Integer, Row>> iterable, Collector<Row> collector) throws Exception {
                        final List<Row> trainRows = new ArrayList<>();
                        final List<Row> validationRows = new ArrayList<>();
                        for (Tuple3<Integer, Integer, Row> tagged : iterable) {
                            if (tagged.f1.intValue() == 0) {
                                trainRows.add(tagged.f2);
                            } else {
                                validationRows.add(tagged.f2);
                            }
                        }
                        final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
                        final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
                        final DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
                        final int startPos = (int) distributedInfo.startPos(subtaskIndex, parallelism, size);
                        final int localCount = (int) distributedInfo.localRowCnt(subtaskIndex, parallelism, size);
                        LocalOperator<?> trainData = new MemSourceLocalOp(trainRows, TableSchema.builder().fields(fieldNames, fieldDataTypes).build());
                        LocalOperator<?> validationData = new MemSourceLocalOp(validationRows, TableSchema.builder().fields(fieldNames, fieldDataTypes).build());
                        // An empty side borrows the other side so fitting/evaluation still has data.
                        if (trainData.getOutputTable().getNumRow() == 0) {
                            trainData = validationData;
                        }
                        if (validationData.getOutputTable().getNumRow() == 0) {
                            validationData = trainData;
                        }
                        for (int idx = startPos; idx < startPos + localCount; idx++) {
                            final Pipeline pipeline = Pipeline.loadLocal(new MemSourceLocalOp(pipelineTables.get(idx)));
                            final TuningEvaluator<?> evaluator = (TuningEvaluator<?>) cls.getConstructor(Params.class).newInstance(params);
                            final PipelineModel model = pipeline.fit(trainData);
                            final double trainMetric = evaluator.evaluate(model.transform(trainData));
                            final double validationMetric = evaluator.evaluate(model.transform(validationData));
                            collector.collect(Row.of(idx, pipeline, trainMetric, validationMetric));
                        }
                    }
                }).name("parallel_standalone_build_model").collect();
            } catch (Exception e) {
                // Rethrow instead of swallowing: with no results there is nothing left to select from.
                throw new RuntimeException("Parallel train-validation tuning job failed.", e);
            }
            for (Row row : results) {
                final Integer taskId = (Integer) row.getField(0);
                final Pipeline pipeline = (Pipeline) row.getField(1);
                final Double validationMetric = (Double) row.getField(3);
                System.out.println(String.format("taskId:%d params:%s trainMetric:%f testMetric:%f", taskId, pipeline.get(pipeline.size() - 1).getParams().toString(), (Double) row.getField(2), validationMetric));
                reportElements.set(taskId.intValue(), new Report.ReportElement(pipeline, candidateParams.get(taskId.intValue()), validationMetric));
            }
            for (int i = 0; i < reportElements.size(); i++) {
                final Report.ReportElement element = reportElements.get(i);
                if (element == null) {
                    throw new IllegalStateException("No tuning result was returned for candidate " + i);
                }
                experienceScores.add(element.getMetric());
            }
        } else {
            final SplitBatchOp splitter = (SplitBatchOp) new SplitBatchOp().setFraction(Double.valueOf(d)).setMLEnvironmentId(getMLEnvironmentId());
            final SplitBatchOp linkFrom = splitter.linkFrom((BatchOperator) new TableSourceBatchOp(DataSetConversionUtil.toTable(batchOperator.getMLEnvironmentId(), shuffle(batchOperator.getDataSet()), batchOperator.getSchema())).setMLEnvironmentId(getMLEnvironmentId()));
            for (int i = 0; i < size; i++) {
                final Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate;
                try {
                    candidate = pipelineCandidatesBase.get(i, experienceScores);
                } catch (CloneNotSupportedException e) {
                    throw new RuntimeException(e);
                }
                double validationMetric = Double.NaN;
                String failure = null;
                try {
                    final Pipeline pipeline = candidate.f0;
                    final PipelineModel model = pipeline.fit((BatchOperator<?>) linkFrom);
                    validationMetric = this.tuningEvaluator.evaluate(model.transform(linkFrom.getSideOutput(0)));
                    final double trainMetric = this.tuningEvaluator.evaluate(model.transform(linkFrom));
                    System.out.println(String.format("taskId:%d params:%s trainMetric:%f testMetric:%f", i, pipeline.get(pipeline.size() - 1).getParams().toString(), trainMetric, validationMetric));
                } catch (Exception e) {
                    failure = ExceptionUtils.stringifyException(e);
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println(String.format("BestTVSplit, i: %d, metric: %f, exception: %s", i, validationMetric, failure));
                    }
                }
                // Exactly one score per candidate, at the candidate's own index.
                experienceScores.add(i, validationMetric);
                if (failure != null) {
                    reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, validationMetric, failure));
                } else if (Double.isNaN(validationMetric)) {
                    reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, validationMetric, "Metric is nan."));
                } else {
                    reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, validationMetric));
                }
            }
        }

        int bestIndex = -1;
        double bestMetric = 0.0d;
        for (int i = 0; i < reportElements.size(); i++) {
            final Report.ReportElement element = reportElements.get(i);
            final Pipeline pipeline = element.getPipeline();
            final double metric = element.getMetric().doubleValue();
            // NaN marks a failed candidate; it must not take part in the selection.
            if (!Double.isNaN(metric)) {
                final boolean improves = bestIndex == -1
                    || (this.tuningEvaluator.isLargerBetter() ? bestMetric < metric : bestMetric > metric);
                if (improves) {
                    bestMetric = metric;
                    bestIndex = i;
                }
            }
            System.out.println(String.format("BestTVSplit, i: %d, params: %s, best: %f, metric: %f", i, pipeline.get(pipeline.size() - 1).getParams().toString(), bestMetric, element.getMetric()));
        }
        if (bestIndex < 0) {
            throw new RuntimeException("Can not find a best model. Report: " + new Report(this.tuningEvaluator, reportElements).toPrettyJson());
        }
        try {
            return Tuple2.of(pipelineCandidatesBase.get(bestIndex, experienceScores).f0, new Report(this.tuningEvaluator, reportElements));
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Picks the best candidate pipeline by k-fold cross validation on local data.
     *
     * <p>Each candidate is scored with the local {@code kFoldCv} helper. A candidate whose first
     * result field is NaN is reported but excluded from the selection.
     *
     * @param localOperator          the data to cross-validate on
     * @param i                      the number of folds; must be greater than 1
     * @param pipelineCandidatesBase the candidate pipelines
     * @return the best pipeline (freshly obtained from the candidates) and the full report
     * @throws RuntimeException if a candidate cannot be cloned or no candidate has a non-NaN score
     */
    public Tuple2<Pipeline, Report> findBestCV(LocalOperator<?> localOperator, int i, PipelineCandidatesBase pipelineCandidatesBase) {
        AkPreconditions.checkArgument(i > 1, "numFolds could be greater than 1.");
        final List<Tuple2<Integer, Row>> folds = split(localOperator, i);
        final int candidateCount = pipelineCandidatesBase.size();
        final List<Double> experienceScores = new ArrayList<>(candidateCount);
        final List<Report.ReportElement> reportElements = new ArrayList<>();
        Double bestMetric = null;
        Integer bestIndex = null;
        try {
            for (int idx = 0; idx < candidateCount; idx++) {
                final Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate = pipelineCandidatesBase.get(idx, experienceScores);
                final Tuple3<Double, Double, String> cv = kFoldCv(folds, candidate.f0, localOperator.getSchema(), i, this.tuningEvaluator);
                experienceScores.add(idx, cv.f1);
                // NOTE(review): selection below uses cv.f0 while the report and the experience
                // scores use cv.f1. findBestCVOptimizeMode labels these two fields trainMetric and
                // valMetric respectively — confirm that selecting on f0 is intended.
                final double selectionMetric = cv.f0.doubleValue();
                reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, cv.f1, cv.f2));
                if (Double.isNaN(selectionMetric)) {
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println(String.format("BestCV, i: %d, best: %f, avg: %f", idx, bestMetric, cv.f1));
                    }
                    continue;
                }
                final boolean improves = bestMetric == null
                    || (this.tuningEvaluator.isLargerBetter()
                        ? bestMetric.doubleValue() < selectionMetric
                        : bestMetric.doubleValue() > selectionMetric);
                if (improves) {
                    bestMetric = cv.f0;
                    bestIndex = idx;
                }
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("BestCV, i: %d, best: %f, avg: %f", idx, bestMetric, cv.f0));
                }
            }
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
        if (bestIndex == null) {
            throw new RuntimeException("Can not find a best model. Report: " + new Report(this.tuningEvaluator, reportElements).toPrettyJson());
        }
        try {
            return Tuple2.of(pipelineCandidatesBase.get(bestIndex.intValue(), experienceScores).f0, new Report(this.tuningEvaluator, reportElements));
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Picks the best candidate pipeline by k-fold cross validation on batch data.
     *
     * <p>When parallel tuning mode is on, the work is delegated to
     * {@code findBestCVOptimizeMode}. Otherwise candidates are scored one after another with the
     * batch {@code kFoldCv} helper; a candidate with a NaN score is reported but excluded from
     * the selection.
     *
     * @param batchOperator          the data to cross-validate on
     * @param i                      the number of folds; must be greater than 1
     * @param pipelineCandidatesBase the candidate pipelines
     * @return the best pipeline (freshly obtained from the candidates) and the full report
     * @throws RuntimeException if a candidate cannot be cloned or no candidate has a non-NaN score
     */
    public Tuple2<Pipeline, Report> findBestCV(BatchOperator<?> batchOperator, int i, PipelineCandidatesBase pipelineCandidatesBase) {
        AkPreconditions.checkArgument(i > 1, "numFolds could be greater than 1.");
        if (getParallelTuningMode().booleanValue()) {
            return findBestCVOptimizeMode(batchOperator, i, pipelineCandidatesBase);
        }
        final DataSet<Tuple2<Integer, Row>> folds = split(batchOperator, i);
        final int candidateCount = pipelineCandidatesBase.size();
        final List<Double> experienceScores = new ArrayList<>(candidateCount);
        final List<Report.ReportElement> reportElements = new ArrayList<>();
        Double bestMetric = null;
        Integer bestIndex = null;
        try {
            for (int idx = 0; idx < candidateCount; idx++) {
                final Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate = pipelineCandidatesBase.get(idx, experienceScores);
                final Tuple2<Double, String> cv = kFoldCv(folds, candidate.f0, batchOperator.getSchema(), i);
                experienceScores.add(idx, cv.f0);
                final double metric = cv.f0.doubleValue();
                reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, cv.f0, cv.f1));
                if (Double.isNaN(metric)) {
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println(String.format("BestCV, i: %d, best: %f, avg: %f", idx, bestMetric, cv.f0));
                    }
                    continue;
                }
                final boolean improves = bestMetric == null
                    || (this.tuningEvaluator.isLargerBetter()
                        ? bestMetric.doubleValue() < metric
                        : bestMetric.doubleValue() > metric);
                if (improves) {
                    bestMetric = cv.f0;
                    bestIndex = idx;
                }
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("BestCV, i: %d, best: %f, avg: %f", idx, bestMetric, cv.f0));
                }
            }
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
        if (bestIndex == null) {
            throw new RuntimeException("Can not find a best model. Report: " + new Report(this.tuningEvaluator, reportElements).toPrettyJson());
        }
        try {
            return Tuple2.of(pipelineCandidatesBase.get(bestIndex.intValue(), experienceScores).f0, new Report(this.tuningEvaluator, reportElements));
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Picks the best candidate pipeline by k-fold cross validation, evaluating candidates in
     * parallel inside a single batch job.
     *
     * <p>Every candidate is saved to a local table, every fold-tagged row is sent to every
     * subtask, and each subtask runs the local {@code kFoldCv} helper on its own contiguous range
     * of candidates. The fourth field of each result row is used as the candidate's reported
     * metric. A failure of the job is rethrown with its cause.
     *
     * <p>Candidates with a NaN metric are never selected as best. Every ordered comparison against
     * NaN is false, so a NaN first candidate would otherwise stay "best" forever.
     *
     * @param batchOperator          the data to cross-validate on
     * @param i                      the number of folds
     * @param pipelineCandidatesBase the candidate pipelines
     * @return the best pipeline (freshly obtained from the candidates) and the full report
     * @throws RuntimeException      if a candidate cannot be cloned, the job fails, or no candidate
     *                               produced a non-NaN metric
     * @throws IllegalStateException if the job returned no result for some candidate
     */
    protected Tuple2<Pipeline, Report> findBestCVOptimizeMode(BatchOperator<?> batchOperator, final int i, PipelineCandidatesBase pipelineCandidatesBase) {
        final DataSet<Tuple2<Integer, Row>> folds = split(batchOperator, i);
        final int size = pipelineCandidatesBase.size();
        final List<Report.ReportElement> reportElements = new ArrayList<>();
        final List<Double> experienceScores = new ArrayList<>(size);
        // The evaluator is rebuilt inside each subtask from its class and params.
        final Params params = this.tuningEvaluator.getParams();
        final Class<?> cls = this.tuningEvaluator.getClass();
        final ArrayList<MTable> pipelineTables = new ArrayList<>();
        final List<Double> noScoresYet = new ArrayList<>();
        final List<List<Tuple3<Integer, ParamInfo, Object>>> candidateParams = new ArrayList<>();
        for (int idx = 0; idx < size; idx++) {
            // Placeholder; replaced once the job reports this candidate's metric.
            reportElements.add(null);
            try {
                // One get() per candidate, so the saved pipeline and its parameter items match.
                final Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate = pipelineCandidatesBase.get(idx, noScoresYet);
                pipelineTables.add(candidate.f0.saveLocal().getOutputTable());
                candidateParams.add(candidate.f1);
            } catch (CloneNotSupportedException e) {
                throw new RuntimeException(e);
            }
        }
        final TableSchema schema = batchOperator.getSchema();
        final String[] fieldNames = schema.getFieldNames();
        final DataType[] fieldDataTypes = schema.getFieldDataTypes();
        final List<Row> results;
        try {
            results = folds.rebalance().mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple3<Integer, Integer, Row>>() {
                @Override
                public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple3<Integer, Integer, Row>> collector) throws Exception {
                    final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
                    for (Tuple2<Integer, Row> foldRow : iterable) {
                        // Send every fold-tagged row to every subtask.
                        for (int target = 0; target < parallelism; target++) {
                            collector.collect(Tuple3.of(target, foldRow.f0, foldRow.f1));
                        }
                    }
                }
            }).partitionCustom(new Partitioner<Integer>() {
                @Override
                public int partition(Integer num, int numPartitions) {
                    return num.intValue() % numPartitions;
                }
            }, 0).mapPartition(new RichMapPartitionFunction<Tuple3<Integer, Integer, Row>, Row>() {
                @Override
                public void mapPartition(Iterable<Tuple3<Integer, Integer, Row>> iterable, Collector<Row> collector) throws Exception {
                    final List<Tuple2<Integer, Row>> localFolds = new ArrayList<>();
                    for (Tuple3<Integer, Integer, Row> tagged : iterable) {
                        localFolds.add(Tuple2.of(tagged.f1, tagged.f2));
                    }
                    final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
                    final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
                    final DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
                    final int startPos = (int) distributedInfo.startPos(subtaskIndex, parallelism, size);
                    final int localCount = (int) distributedInfo.localRowCnt(subtaskIndex, parallelism, size);
                    final TableSchema localSchema = TableSchema.builder().fields(fieldNames, fieldDataTypes).build();
                    for (int idx = startPos; idx < startPos + localCount; idx++) {
                        final Pipeline pipeline = Pipeline.loadLocal(new MemSourceLocalOp(pipelineTables.get(idx)));
                        final TuningEvaluator<?> evaluator = (TuningEvaluator<?>) cls.getConstructor(Params.class).newInstance(params);
                        final Tuple3<Double, Double, String> cv = BaseTuning.kFoldCv(localFolds, pipeline, localSchema, i, evaluator);
                        collector.collect(Row.of(idx, pipeline, cv.f0, cv.f1));
                    }
                }
            }).name("parallel_standalone_build_model").collect();
        } catch (Exception e) {
            // Rethrow instead of swallowing: with no results there is nothing left to select from.
            throw new RuntimeException("Parallel cross-validation tuning job failed.", e);
        }
        for (Row row : results) {
            final Integer taskId = (Integer) row.getField(0);
            final Pipeline pipeline = (Pipeline) row.getField(1);
            final Double trainMetric = (Double) row.getField(2);
            final Double validationMetric = (Double) row.getField(3);
            System.out.println(String.format("kFoldCv, i: %d, params: %s, trainMetric: %f, valMetric: %f", taskId, pipeline.get(pipeline.size() - 1).getParams().toString(), trainMetric, validationMetric));
            reportElements.set(taskId.intValue(), new Report.ReportElement(pipeline, candidateParams.get(taskId.intValue()), validationMetric));
        }

        int bestIndex = -1;
        double bestMetric = 0.0d;
        for (int idx = 0; idx < reportElements.size(); idx++) {
            final Report.ReportElement element = reportElements.get(idx);
            if (element == null) {
                throw new IllegalStateException("No tuning result was returned for candidate " + idx);
            }
            final Pipeline pipeline = element.getPipeline();
            experienceScores.add(idx, element.getMetric());
            final double metric = element.getMetric().doubleValue();
            // NaN marks a failed candidate; it must not take part in the selection.
            if (!Double.isNaN(metric)) {
                final boolean improves = bestIndex == -1
                    || (this.tuningEvaluator.isLargerBetter() ? bestMetric < metric : bestMetric > metric);
                if (improves) {
                    bestMetric = metric;
                    bestIndex = idx;
                }
            }
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("BestTVSplit, i: %d, params: %s, best: %f, metric: %f", idx, pipeline.get(pipeline.size() - 1).getParams().toString(), bestMetric, element.getMetric()));
            }
        }
        if (bestIndex < 0) {
            throw new RuntimeException("Can not find a best model. Report: " + new Report(this.tuningEvaluator, reportElements).toPrettyJson());
        }
        try {
            return Tuple2.of(pipelineCandidatesBase.get(bestIndex, experienceScores).f0, new Report(this.tuningEvaluator, reportElements));
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Runs k-fold cross validation of {@code pipeline} over fold-tagged batch data.
     *
     * <p>For every fold index in {@code [0, i)}, the rows whose tag differs from the fold index are
     * used to fit the pipeline, and the rows whose tag equals the fold index are transformed by the
     * fitted model and scored with {@code this.tuningEvaluator}. A fold that throws is skipped: its
     * stringified exception is appended to the returned message and it does not contribute to the
     * average.
     *
     * @param dataSet     rows paired with the fold index they were assigned to
     * @param pipeline    the candidate pipeline to fit on each training subset
     * @param tableSchema schema used when converting each row subset back into a table
     * @param i           number of folds
     * @return a pair of (mean metric over the folds that succeeded, accumulated failure text);
     *         the metric is {@code NaN} when no fold succeeded
     */
    private Tuple2<Double, String> kFoldCv(DataSet<Tuple2<Integer, Row>> dataSet, Pipeline pipeline, TableSchema tableSchema, int i) {
        double metricSum = 0.0d;
        int succeededFolds = 0;
        StringBuilder failures = new StringBuilder();
        for (int fold = 0; fold < i; fold++) {
            // The anonymous functions below need an effectively final copy of the loop index.
            final int heldOutFold = fold;
            double foldMetric = Double.NaN;
            try {
                // Training subset: every row that is NOT tagged with the held-out fold.
                DataSet<Row> trainRows = dataSet.filter(new FilterFunction<Tuple2<Integer, Row>>() {
                    private static final long serialVersionUID = 2249884521437544236L;

                    public boolean filter(Tuple2<Integer, Row> tagged) {
                        return tagged.f0.intValue() != heldOutFold;
                    }
                }).map(new MapFunction<Tuple2<Integer, Row>, Row>() {
                    private static final long serialVersionUID = 2618229645786221757L;

                    public Row map(Tuple2<Integer, Row> tagged) {
                        return tagged.f1;
                    }
                });
                // Validation subset: only the rows tagged with the held-out fold.
                DataSet<Row> validationRows = dataSet.filter(new FilterFunction<Tuple2<Integer, Row>>() {
                    private static final long serialVersionUID = 5811166054549336470L;

                    public boolean filter(Tuple2<Integer, Row> tagged) {
                        return tagged.f0.intValue() == heldOutFold;
                    }
                }).map(new MapFunction<Tuple2<Integer, Row>, Row>() {
                    private static final long serialVersionUID = -1760709990316111721L;

                    public Row map(Tuple2<Integer, Row> tagged) {
                        return tagged.f1;
                    }
                });

                BatchOperator<?> trainOp = (BatchOperator<?>) new TableSourceBatchOp(
                    DataSetConversionUtil.toTable(getMLEnvironmentId(), trainRows, tableSchema))
                    .setMLEnvironmentId(getMLEnvironmentId());
                foldMetric = this.tuningEvaluator.evaluate(
                    pipeline.fit(trainOp).transform(new TableSourceBatchOp(
                        DataSetConversionUtil.toTable(getMLEnvironmentId(), validationRows, tableSchema))));

                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("kFoldCv, k: %d, i: %d, metric: %f", i, fold, foldMetric));
                }
                metricSum += foldMetric;
                succeededFolds++;
            } catch (Exception e) {
                // Best effort: one failing fold must not abort the remaining folds.
                String trace = ExceptionUtils.stringifyException(e);
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("kFoldCv err, k: %d, i: %d, metric: %f, exception: %s", i, fold, foldMetric, trace));
                }
                failures.append(trace).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            }
        }
        if (succeededFolds == 0) {
            failures.append("valid size is zero.").append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            return Tuple2.of(Double.NaN, failures.toString());
        }
        // succeededFolds only ever counts up from zero, so it is strictly positive here.
        return Tuple2.of(metricSum / succeededFolds, failures.toString());
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Runs k-fold cross validation of {@code pipeline} over fold-tagged in-memory rows.
     *
     * <p>For every fold index in {@code [0, i)}, the rows whose tag differs from the fold index form
     * the training set and the rows whose tag equals the fold index form the held-out validation
     * set. This is the same partition the {@code DataSet} overload of this method uses. The
     * pipeline is fitted on the training set and then scored on both sets with
     * {@code tuningEvaluator}.
     *
     * <p>An exception thrown while transforming or evaluating a fold is recorded in the returned
     * message and the fold is left out of both averages.
     * NOTE(review): {@code pipeline.fit} runs outside the {@code try}, so a failure while fitting
     * still propagates to the caller (unlike the {@code DataSet} overload) - confirm this is intended.
     *
     * @param list            rows paired with the fold index they were assigned to
     * @param pipeline        the candidate pipeline to fit on each training set
     * @param tableSchema     schema of the rows
     * @param i               number of folds
     * @param tuningEvaluator evaluator that turns a prediction result into a metric
     * @return a triple of (mean metric on the training sets, mean metric on the held-out sets,
     *         accumulated failure text); both metrics are {@code NaN} when no fold succeeded
     */
    public static Tuple3<Double, Double, String> kFoldCv(List<Tuple2<Integer, Row>> list, Pipeline pipeline, TableSchema tableSchema, int i, TuningEvaluator tuningEvaluator) {
        double validationMetricSum = 0.0d;
        double trainMetricSum = 0.0d;
        int succeededFolds = 0;
        StringBuilder failures = new StringBuilder();
        for (int fold = 0; fold < i; fold++) {
            List<Row> trainRows = new ArrayList<>();
            List<Row> validationRows = new ArrayList<>();
            for (Tuple2<Integer, Row> tagged : list) {
                // Hold out exactly the rows tagged with the current fold; train on everything else.
                if (tagged.f0.intValue() == fold) {
                    validationRows.add(tagged.f1);
                } else {
                    trainRows.add(tagged.f1);
                }
            }
            MemSourceLocalOp trainOp = new MemSourceLocalOp(trainRows, tableSchema);
            MemSourceLocalOp validationOp = new MemSourceLocalOp(validationRows, tableSchema);
            PipelineModel model = pipeline.fit((LocalOperator<?>) trainOp);
            double validationMetric = Double.NaN;
            try {
                validationMetric = tuningEvaluator.evaluate(model.transform(validationOp));
                double trainMetric = tuningEvaluator.evaluate(model.transform(trainOp));
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("kFoldCv, k: %d, i: %d, metric: %f", i, fold, validationMetric));
                }
                validationMetricSum += validationMetric;
                trainMetricSum += trainMetric;
                succeededFolds++;
            } catch (Exception e) {
                // Best effort: one failing fold must not abort the remaining folds.
                String trace = ExceptionUtils.stringifyException(e);
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println(String.format("kFoldCv err, k: %d, i: %d, metric: %f, exception: %s", i, fold, validationMetric, trace));
                }
                failures.append(trace).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            }
        }
        if (succeededFolds == 0) {
            failures.append("valid size is zero.").append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            return Tuple3.of(Double.NaN, Double.NaN, failures.toString());
        }
        // succeededFolds only ever counts up from zero, so it is strictly positive here.
        double meanValidationMetric = validationMetricSum / succeededFolds;
        double meanTrainMetric = trainMetricSum / succeededFolds;
        return Tuple3.of(meanTrainMetric, meanValidationMetric, failures.toString());
    }

    /**
     * Randomly reorders the rows of a data set.
     *
     * <p>Each row is paired with a random non-negative {@code long} key, sent to the partition
     * {@code key % numPartitions}, sorted by key (ascending) inside that partition, and finally
     * stripped of the key again.
     *
     * @param dataSet the rows to shuffle
     * @return the same rows, redistributed and reordered by their random keys
     */
    private DataSet<Row> shuffle(DataSet<Row> dataSet) {
        // Step 1: attach a random sort key in [0, Long.MAX_VALUE) to every row.
        DataSet<Tuple2<Long, Row>> keyed = dataSet.map(new MapFunction<Row, Tuple2<Long, Row>>() { // from class: com.alibaba.alink.pipeline.tuning.BaseTuning.13
            private static final long serialVersionUID = 2565906511879493627L;

            public Tuple2<Long, Row> map(Row row) {
                long sortKey = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE);
                return Tuple2.of(sortKey, row);
            }
        });

        // Step 2: scatter the rows across partitions according to the key.
        DataSet<Tuple2<Long, Row>> scattered = keyed.partitionCustom(new Partitioner<Long>() { // from class: com.alibaba.alink.pipeline.tuning.BaseTuning.12
            private static final long serialVersionUID = 8626504946902766931L;

            public int partition(Long key, int numPartitions) {
                // Keys are non-negative, so the remainder is a valid partition index.
                return (int) (key % numPartitions);
            }
        }, 0);

        // Step 3: order each partition by the key, then drop the key.
        return scattered.sortPartition(0, Order.ASCENDING).map(new MapFunction<Tuple2<Long, Row>, Row>() { // from class: com.alibaba.alink.pipeline.tuning.BaseTuning.11
            private static final long serialVersionUID = 2667225910228407097L;

            public Row map(Tuple2<Long, Row> keyedRow) {
                return keyedRow.f1;
            }
        });
    }

    /**
     * Shuffles the rows of {@code batchOperator} and tags each one with a fold index.
     *
     * <p>Every subtask derives a global offset for its first row from the broadcast per-partition
     * row counts, locates the fold that contains that offset, and then labels its rows
     * consecutively, moving on to the next fold each time the current fold's end offset is reached.
     * Fold boundaries come from {@code DefaultDistributedInfo.startPos} / {@code localRowCnt}
     * called with (fold, {@code i}, total row count) - presumably the first offset and the size of
     * each of the {@code i} slices; verify against that class.
     *
     * @param batchOperator source of the rows
     * @param i             number of folds
     * @return the shuffled rows, each paired with its fold index
     */
    private DataSet<Tuple2<Integer, Row>> split(BatchOperator<?> batchOperator, final int i) {
        DataSet<Row> shuffled = shuffle(batchOperator.getDataSet());
        return shuffled.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() { // from class: com.alibaba.alink.pipeline.tuning.BaseTuning.14
            private static final long serialVersionUID = -902599228310615694L;
            // Global offset assigned to the first row of this subtask.
            long taskStart = 0;
            // Row count summed over all partitions.
            long totalNumInstance = 0;

            public void open(Configuration configuration) {
                List<Tuple2<Integer, Long>> partitionCounts = getRuntimeContext().getBroadcastVariable("counts");
                int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
                for (Tuple2<Integer, Long> partitionCount : partitionCounts) {
                    // The offset of this subtask is the number of rows held by partitions with a
                    // HIGHER index than its own.
                    if (subtaskIndex < partitionCount.f0) {
                        this.taskStart += partitionCount.f1;
                    }
                    this.totalNumInstance += partitionCount.f1;
                }
            }

            public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Row>> collector) {
                DefaultDistributedInfo foldLayout = new DefaultDistributedInfo();
                // -1 / -1L are the "not found" defaults; the first row then opens fold 0.
                int currentFold = -1;
                long currentFoldEnd = -1L;

                // Locate the fold whose offset range contains taskStart.
                for (int candidate = 0; candidate <= i; candidate++) {
                    long candidateStart = foldLayout.startPos(candidate, i, this.totalNumInstance);
                    long candidateSize = foldLayout.localRowCnt(candidate, i, this.totalNumInstance);
                    if (this.taskStart < candidateStart) {
                        // taskStart lies strictly inside the previous fold.
                        currentFold = candidate - 1;
                        currentFoldEnd = foldLayout.startPos(candidate - 1, i, this.totalNumInstance)
                            + foldLayout.localRowCnt(candidate - 1, i, this.totalNumInstance);
                        break;
                    }
                    if (this.taskStart == candidateStart) {
                        // taskStart is exactly the first offset of this fold.
                        currentFold = candidate;
                        currentFoldEnd = candidateStart + candidateSize;
                        break;
                    }
                }

                long offset = this.taskStart;
                for (Row row : iterable) {
                    if (offset >= currentFoldEnd) {
                        // Current fold is exhausted; the next one starts at this offset.
                        currentFold++;
                        currentFoldEnd = foldLayout.localRowCnt(currentFold, i, this.totalNumInstance) + offset;
                    }
                    collector.collect(Tuple2.of(currentFold, row));
                    offset++;
                }
            }
        }).withBroadcastSet(DataSetUtils.countElementsPerPartition(shuffled), "counts");
    }

    /**
     * Tags every row of a local operator's output table with a random fold index.
     *
     * <p>The index for each row is drawn independently from {@code [0, i)}, so the folds are not
     * guaranteed to have equal sizes.
     *
     * @param localOperator operator whose output rows are to be split
     * @param i             number of folds
     * @return the rows in their original order, each paired with its fold index
     */
    private List<Tuple2<Integer, Row>> split(LocalOperator<?> localOperator, int i) {
        List<Tuple2<Integer, Row>> taggedRows = new ArrayList<>();
        for (Row row : localOperator.getOutputTable().getRows()) {
            int fold = ThreadLocalRandom.current().nextInt(i);
            taggedRows.add(Tuple2.of(fold, row));
        }
        return taggedRows;
    }

    /**
     * Bridge overload (see the bridge/synthetic markers on the declaration) that accepts the raw
     * {@code LocalOperator} type and forwards to {@code fit(LocalOperator<?>)}.
     *
     * @param localOperator the local training input
     * @return whatever {@code fit(LocalOperator<?>)} returns, typed as {@code ModelBase}
     */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ ModelBase fit(LocalOperator localOperator) {
        return fit((LocalOperator<?>) localOperator);
    }

    /**
     * Bridge overload (see the bridge/synthetic markers on the declaration) that accepts the raw
     * {@code StreamOperator} type and forwards to {@code fit(StreamOperator<?>)}.
     *
     * @param streamOperator the streaming training input
     * @return whatever {@code fit(StreamOperator<?>)} returns, typed as {@code ModelBase}
     */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ ModelBase fit(StreamOperator streamOperator) {
        return fit((StreamOperator<?>) streamOperator);
    }

    /**
     * Bridge overload (see the bridge/synthetic markers on the declaration) that accepts the raw
     * {@code BatchOperator} type and forwards to {@code fit(BatchOperator<?>)}.
     *
     * @param batchOperator the batch training input
     * @return whatever {@code fit(BatchOperator<?>)} returns, typed as {@code ModelBase}
     */
    @Override // com.alibaba.alink.pipeline.EstimatorBase
    public /* bridge */ /* synthetic */ ModelBase fit(BatchOperator batchOperator) {
        return fit((BatchOperator<?>) batchOperator);
    }
}
