package com.alibaba.alink.pipeline.tuning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.source.MemSourceLocalOp;
import com.alibaba.alink.params.tuning.HasTrainRatio;
import com.alibaba.alink.pipeline.EstimatorBase;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.tuning.Report;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Bayes search over pipeline candidates, scored with a single random train/validation split.
 *
 * <p>When parallel tuning mode is off (or the input is a {@link LocalOperator}), the work is
 * delegated to the inherited {@code findBestTVSplit}. When parallel tuning mode is on, the
 * candidate indices are divided among the parallel subtasks; every subtask receives its own
 * randomly split copy of the data and fits/evaluates its share of candidates locally.
 */
@NameCn("Bayes搜索TV")
public class BayesSearchTVSplit extends BaseBayesSearch<BayesSearchTVSplit, BayesSearchTVSplitModel> implements HasTrainRatio<BayesSearchTVSplit> {
    private static final long serialVersionUID = 5726020877861859001L;

    /**
     * Finds the best pipeline among the candidates for a batch input.
     *
     * @param batchOperator the input data
     * @param pipelineCandidatesBayes the candidates to evaluate
     * @return the best pipeline together with the report of all evaluated candidates
     * @throws RuntimeException if the distributed evaluation job fails, if a candidate has no
     *     result, or if no best candidate can be determined
     */
    @Override
    protected Tuple2<Pipeline, Report> findBest(BatchOperator<?> batchOperator, final PipelineCandidatesBayes pipelineCandidatesBayes) {
        final double trainRatio = getTrainRatio().doubleValue();
        if (!getParallelTuningMode().booleanValue()) {
            return findBestTVSplit(batchOperator, trainRatio, pipelineCandidatesBayes);
        }

        final int candidateCount = pipelineCandidatesBayes.size();
        TuningEvaluator<?> tuningEvaluator = getTuningEvaluator();
        // Each subtask re-creates the evaluator from its class and params (see the reflective call below).
        final Params evaluatorParams = tuningEvaluator.getParams();
        final Class<?> evaluatorClass = tuningEvaluator.getClass();

        TableSchema schema = batchOperator.getSchema();
        final String[] fieldNames = schema.getFieldNames();
        final DataType[] fieldDataTypes = schema.getFieldDataTypes();

        List<Row> results;
        try {
            results = batchOperator.rebalance().getDataSet().mapPartition(new RichMapPartitionFunction<Row, Tuple3<Integer, Integer, Row>>() {
                @Override
                public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Integer, Integer, Row>> collector) throws Exception {
                    int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
                    for (Row row : iterable) {
                        // Emit one copy of the row per target subtask, drawing the
                        // train (0) / validation (1) label independently for each copy.
                        for (int target = 0; target < numberOfParallelSubtasks; target++) {
                            if (ThreadLocalRandom.current().nextDouble() <= trainRatio) {
                                collector.collect(Tuple3.of(Integer.valueOf(target), 0, row));
                            } else {
                                collector.collect(Tuple3.of(Integer.valueOf(target), 1, row));
                            }
                        }
                    }
                }
            }).partitionCustom(new Partitioner<Integer>() {
                @Override
                public int partition(Integer key, int numPartitions) {
                    return key.intValue() % numPartitions;
                }
            }, 0).mapPartition(new RichMapPartitionFunction<Tuple3<Integer, Integer, Row>, Row>() {
                @Override
                public void mapPartition(Iterable<Tuple3<Integer, Integer, Row>> iterable, Collector<Row> collector) throws Exception {
                    List<Row> trainRows = new ArrayList<>();
                    List<Row> validationRows = new ArrayList<>();
                    for (Tuple3<Integer, Integer, Row> labeled : iterable) {
                        if (labeled.f1.intValue() == 0) {
                            trainRows.add(labeled.f2);
                        } else {
                            validationRows.add(labeled.f2);
                        }
                    }

                    // Range of candidate indices handled by this subtask.
                    int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                    int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
                    DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
                    int startPos = (int) distributedInfo.startPos(indexOfThisSubtask, numberOfParallelSubtasks, candidateCount);
                    int localRowCnt = (int) distributedInfo.localRowCnt(indexOfThisSubtask, numberOfParallelSubtasks, candidateCount);

                    MemSourceLocalOp trainData = new MemSourceLocalOp(trainRows, TableSchema.builder().fields(fieldNames, fieldDataTypes).build());
                    MemSourceLocalOp validationData = new MemSourceLocalOp(validationRows, TableSchema.builder().fields(fieldNames, fieldDataTypes).build());
                    // If the random split left one side empty, reuse the other side for both roles.
                    if (trainData.getOutputTable().getNumRow() == 0) {
                        trainData = validationData;
                    }
                    if (validationData.getOutputTable().getNumRow() == 0) {
                        validationData = trainData;
                    }

                    // Scores of the candidates already evaluated by this subtask; passed back
                    // into the candidate generator when requesting the next candidate.
                    List<Double> localScores = new ArrayList<>();
                    for (int candidateIndex = startPos; candidateIndex < startPos + localRowCnt; candidateIndex++) {
                        Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate = pipelineCandidatesBayes.get(candidateIndex, localScores);
                        Pipeline pipeline = candidate.f0;
                        TuningEvaluator<?> localEvaluator = (TuningEvaluator<?>) evaluatorClass.getConstructor(Params.class).newInstance(evaluatorParams);
                        PipelineModel model = pipeline.fit((LocalOperator<?>) trainData);
                        double testMetric = localEvaluator.evaluate(model.transform(validationData));
                        double trainMetric = localEvaluator.evaluate(model.transform(trainData));
                        localScores.add(testMetric);
                        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                            System.out.println(String.format("Iter %d in group %d, taskId:%d params:%s trainMetric:%f testMetric:%f", candidateIndex - startPos, indexOfThisSubtask, candidateIndex, pipeline.get(pipeline.size() - 1).getParams().toString(), trainMetric, testMetric));
                        }
                        collector.collect(Row.of(candidateIndex, indexOfThisSubtask, pipeline, testMetric, candidate.f1));
                    }
                }
            }).name("parallel_standalone_build_model").collect();
        } catch (Exception e) {
            // Surface the real failure instead of printing it and failing later with an unrelated NPE.
            throw new RuntimeException("Parallel evaluation of pipeline candidates failed.", e);
        }

        // Slot i holds the report element of candidate i; filled from the collected rows below.
        List<Report.ReportElement> reportElements = new ArrayList<>(candidateCount);
        for (int i = 0; i < candidateCount; i++) {
            reportElements.add(null);
        }
        for (Row row : results) {
            Integer candidateIndex = (Integer) row.getField(0);
            Pipeline pipeline = (Pipeline) row.getField(2);
            Double metric = (Double) row.getField(3);
            @SuppressWarnings("unchecked") // field 4 is written above as candidate.f1, which has this type
            List<Tuple3<Integer, ParamInfo, Object>> stageParams = (List<Tuple3<Integer, ParamInfo, Object>>) row.getField(4);
            pipelineCandidatesBayes.checkParamsSameValueOrNullValue(candidateIndex.intValue(), stageParams);
            reportElements.set(candidateIndex.intValue(), new Report.ReportElement(pipeline, stageParams, metric));
        }

        List<Double> metrics = new ArrayList<>(candidateCount);
        int bestIndex = -1;
        double bestMetric = 0.0d;
        for (int i = 0; i < reportElements.size(); i++) {
            Report.ReportElement reportElement = reportElements.get(i);
            if (reportElement == null) {
                throw new IllegalStateException("No evaluation result was returned for pipeline candidate " + i + ".");
            }
            metrics.add(i, reportElement.getMetric());
            double metric = reportElement.getMetric().doubleValue();
            if (bestIndex == -1
                || (tuningEvaluator.isLargerBetter() && bestMetric < metric)
                || (!tuningEvaluator.isLargerBetter() && bestMetric > metric)) {
                bestMetric = metric;
                bestIndex = i;
            }
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("BestTVSplit, i: %d, best: %f, metric: %f", i, bestMetric, reportElement.getMetric()));
            }
        }
        if (bestIndex < 0) {
            throw new RuntimeException("Can not find a best model. Report: " + new Report(tuningEvaluator, reportElements).toPrettyJson());
        }
        try {
            return Tuple2.of(pipelineCandidatesBayes.get(bestIndex, metrics).f0, new Report(tuningEvaluator, reportElements));
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Finds the best pipeline among the candidates for a local input by delegating to
     * {@code findBestTVSplit} with the configured train ratio.
     *
     * @param localOperator the input data
     * @param pipelineCandidatesBayes the candidates to evaluate
     * @return the result of {@code findBestTVSplit}
     */
    @Override
    protected Tuple2<Pipeline, Report> findBest(LocalOperator<?> localOperator, PipelineCandidatesBayes pipelineCandidatesBayes) {
        return findBestTVSplit(localOperator, getTrainRatio().doubleValue(), pipelineCandidatesBayes);
    }
}
