package com.alibaba.alink.pipeline.tuning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.params.tuning.BayesTuningParams;
import com.alibaba.alink.pipeline.EstimatorBase;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineStageBase;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/pipeline/tuning/PipelineCandidatesBayes.class */
/**
 * Candidate generator that proposes hyper-parameter settings for a pipeline using Bayesian
 * optimization.
 *
 * <p>While fewer than {@code warmUpJobs} scores have been observed, each parameter value is drawn
 * from its {@link ValueDist} with a seeded {@link Random}. Afterwards, candidates are chosen by a
 * Gaussian-process surrogate ({@code GP}) or a Tree-structured Parzen Estimator ({@code TPE}),
 * fitted on the settings proposed so far and their scores. Every proposed setting is remembered,
 * so asking again for an already-proposed index returns the same setting.
 */
public class PipelineCandidatesBayes extends PipelineCandidatesBase {
    /**
     * Per tuned parameter: (index of its stage in {@code pipeline}, parameter, value distribution).
     * Transient, so it is rebuilt lazily from {@link #paramDist} when missing or empty.
     */
    transient ArrayList<Tuple3<Integer, ParamInfo, ValueDist>> items;
    /** Search space: (stage, parameter, value distribution) triples. */
    final ParamDist paramDist;
    /** Base seed; candidate {@code i} reseeds {@link #rand} with {@code seed + i * 100000}. */
    final long seed;
    /** Number of observed scores required before the surrogate model is used. */
    final int warmUpJobs;
    /** Number of candidates reported by {@link #size()}. */
    final int nIter;
    /** Random samples drawn per proposal (whole settings for GP, per parameter for TPE). */
    final int sampleNum;
    /** Passed through to {@code TreeParzenEstimator.calc}. */
    final int linearForgetting;
    /** Number of tuned parameters, i.e. {@code paramDist.getItems().size()}. */
    final int dim;
    final BayesTuningParams.BayesStrategy strategy;
    /** Whether a larger score means a better candidate. */
    final boolean isLargerBetter;
    Random rand;
    /** Proposed raw values, indexed by candidate index. */
    private final ArrayList<Object[]> records;
    /**
     * Numeric form of the settings proposed by {@link #get}, fed to the surrogate models.
     * NOTE(review): {@link #checkParamsSameValueOrNullValue} adds to {@code records} but not here,
     * so the two lists can diverge; confirm this is intended by its callers.
     */
    private final ArrayList<DenseVector> recordVectors;

    /**
     * Creates a Bayesian candidate generator.
     *
     * @param estimator        estimator (pipeline) being tuned; handed to the base class
     * @param paramDist        search space
     * @param seed             base random seed
     * @param warmUpJobs       observed-score count below which values are drawn at random
     * @param nIter            number of candidates to report
     * @param sampleNum        random samples evaluated per proposal
     * @param linearForgetting forwarded to the TPE estimator
     * @param strategy         surrogate strategy (GP or TPE)
     * @param isLargerBetter   whether larger scores are better
     */
    public PipelineCandidatesBayes(EstimatorBase estimator, ParamDist paramDist, long seed, int warmUpJobs,
                                   int nIter, int sampleNum, int linearForgetting,
                                   BayesTuningParams.BayesStrategy strategy, boolean isLargerBetter) {
        super(estimator);
        this.rand = new Random();
        this.paramDist = paramDist;
        this.seed = seed;
        this.warmUpJobs = warmUpJobs;
        this.nIter = nIter;
        this.sampleNum = sampleNum;
        this.linearForgetting = linearForgetting;
        this.strategy = strategy;
        this.isLargerBetter = isLargerBetter;
        this.records = new ArrayList<>(nIter);
        this.recordVectors = new ArrayList<>(nIter);
        // dim must be assigned before setItems: a final field read before assignment yields 0,
        // which previously left `items` empty after construction.
        this.dim = paramDist.getItems().size();
        setItems(paramDist);
    }

    /**
     * Resolves each tuned parameter's stage to its index within {@code pipeline} and stores the
     * result in {@link #items}.
     */
    private void setItems(ParamDist paramDist) {
        List<Tuple3<PipelineStageBase, ParamInfo, ValueDist>> distItems = paramDist.getItems();
        int count = distItems.size();
        PipelineStageBase[] stages = new PipelineStageBase[count];
        for (int k = 0; k < count; k++) {
            stages[k] = distItems.get(k).f0;
        }
        int[] stageIndices = findStageIndex(this.pipeline, stages);
        ArrayList<Tuple3<Integer, ParamInfo, ValueDist>> resolved = new ArrayList<>(count);
        for (int k = 0; k < count; k++) {
            Tuple3<PipelineStageBase, ParamInfo, ValueDist> item = distItems.get(k);
            resolved.add(new Tuple3<>(stageIndices[k], item.f1, item.f2));
        }
        this.items = resolved;
    }

    /**
     * Returns the scores in the orientation handed to the TPE estimator: negated when larger is
     * better, unchanged otherwise. NOTE(review): this suggests the estimator treats smaller
     * values as better; confirm against TreeParzenEstimator.
     */
    private static List<Double> tpeProcessExperienceScores(List<Double> scores, boolean isLargerBetter) {
        List<Double> processed = new ArrayList<>(scores.size());
        for (Double score : scores) {
            double value = score;
            processed.add(isLargerBetter ? -value : value);
        }
        return processed;
    }

    @Override // com.alibaba.alink.pipeline.tuning.PipelineCandidatesBase
    public int size() {
        return this.nIter;
    }

    /** Copies {@code values}, each expected to be a {@link Number}, into {@code out} as doubles. */
    private static void fillDoubleArray(Object[] values, double[] out) {
        for (int k = 0; k < values.length; k++) {
            out[k] = ((Number) values[k]).doubleValue();
        }
    }

    /**
     * Returns candidate {@code index}: a fresh copy of the pipeline with that candidate's
     * parameter values applied, together with the (stage index, parameter, value) assignments.
     *
     * <p>If a setting was already recorded for {@code index}, it is replayed. Otherwise a new
     * setting is proposed from {@code scores}, appended to the records, and returned.
     *
     * @param index  candidate index
     * @param scores scores observed so far, assumed aligned with the recorded settings
     * @throws IllegalStateException if a stage cannot be re-instantiated from its Params, or the
     *                               strategy is unsupported
     */
    @Override // com.alibaba.alink.pipeline.tuning.PipelineCandidatesBase
    public Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> get(int index, List<Double> scores)
        throws CloneNotSupportedException {
        if (this.items == null || this.items.isEmpty()) {
            setItems(this.paramDist);
        }
        Object[] values;
        if (index < this.records.size()) {
            values = this.records.get(index);
        } else {
            values = proposeValues(index, scores);
            double[] vector = new double[this.dim];
            fillDoubleArray(values, vector);
            // Append only after conversion succeeded, so records and recordVectors stay in step.
            this.records.add(values);
            this.recordVectors.add(new DenseVector(vector));
        }
        List<Tuple3<Integer, ParamInfo, Object>> assignments = new ArrayList<>(this.dim);
        for (int k = 0; k < this.dim; k++) {
            Tuple3<Integer, ParamInfo, ValueDist> item = this.items.get(k);
            assignments.add(new Tuple3<>(item.f0, item.f1, values[k]));
        }
        Pipeline pipeline = buildPipelineCopy();
        updatePipelineParams(pipeline, assignments);
        return Tuple2.of(pipeline, assignments);
    }

    /** Proposes a new setting for candidate {@code index} using the configured strategy. */
    private Object[] proposeValues(int index, List<Double> scores) {
        // Reseed so the random draws for a candidate depend only on seed and index.
        // Long arithmetic: index * 100000 overflows int once index exceeds 21474.
        this.rand.setSeed(this.seed + index * 100000L);
        if (scores.size() < this.warmUpJobs) {
            return sampleRandomly();
        }
        if (BayesTuningParams.BayesStrategy.GP.equals(this.strategy)) {
            return proposeByGaussianProcess(scores);
        }
        if (BayesTuningParams.BayesStrategy.TPE.equals(this.strategy)) {
            return proposeByTpe(scores);
        }
        throw new IllegalStateException("Unsupported Bayes strategy: " + this.strategy);
    }

    /** Draws one value per parameter from its distribution using {@link #rand}. */
    private Object[] sampleRandomly() {
        Object[] values = new Object[this.dim];
        for (int k = 0; k < this.dim; k++) {
            values[k] = this.items.get(k).f2.get(this.rand.nextDouble());
        }
        return values;
    }

    /**
     * Draws {@code sampleNum} random settings and keeps the one whose GP-predicted value (the
     * first field of {@code calc}'s result, sign-adjusted for {@link #isLargerBetter}) is highest.
     */
    private Object[] proposeByGaussianProcess(List<Double> scores) {
        GaussianProcessRegression gp = new GaussianProcessRegression(this.recordVectors, scores);
        double bestValue = Double.NEGATIVE_INFINITY;
        Object[] best = null;
        Object[] trial = new Object[this.dim];
        double[] trialVector = new double[this.dim];
        for (int s = 0; s < this.sampleNum; s++) {
            for (int k = 0; k < this.dim; k++) {
                trial[k] = this.items.get(k).f2.get(this.rand.nextDouble());
            }
            fillDoubleArray(trial, trialVector);
            double predicted = (Double) gp.calc(new DenseVector(trialVector)).f0;
            double oriented = this.isLargerBetter ? predicted : -predicted;
            if (oriented > bestValue) {
                bestValue = oriented;
                // Keep the winner and reuse the previous best (or a fresh array) as the next buffer.
                Object[] spare = (best == null) ? new Object[this.dim] : best;
                best = trial;
                trial = spare;
            }
        }
        if (best == null) {
            // No sample beat -infinity (sampleNum <= 0 or only NaN/-inf predictions):
            // fall back to a plain random draw instead of returning null values.
            best = sampleRandomly();
        }
        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.printf("GP estimates maximum regression: %f, hyper-parameters values: \n", bestValue);
            for (Object value : best) {
                System.out.printf("%f ", ((Number) value).doubleValue());
            }
            System.out.println();
        }
        return best;
    }

    /**
     * For each parameter independently, draws {@code sampleNum} values and keeps the one at the
     * index returned by the TPE estimator.
     */
    private Object[] proposeByTpe(List<Double> scores) {
        TreeParzenEstimator tpe = new TreeParzenEstimator(
            this.recordVectors, tpeProcessExperienceScores(scores, this.isLargerBetter));
        Object[] chosen = new Object[this.dim];
        for (int k = 0; k < this.dim; k++) {
            ValueDist dist = this.items.get(k).f2;
            Object[] candidates = new Object[this.sampleNum];
            double[] candidateVector = new double[this.sampleNum];
            for (int s = 0; s < this.sampleNum; s++) {
                candidates[s] = dist.get(this.rand.nextDouble());
            }
            fillDoubleArray(candidates, candidateVector);
            Tuple5<Double, Double, Double, Double, Integer> summary = ValueDistUtils.getValueDistSummary(dist);
            int pick = tpe.calc(new DenseVector(candidateVector), k, summary.f0, summary.f1, summary.f2,
                summary.f3, summary.f4, this.linearForgetting);
            chosen[k] = candidates[pick];
        }
        return chosen;
    }

    /**
     * Re-instantiates every stage of {@code pipeline} through its {@code (Params)} constructor,
     * using the stored per-stage Params, and returns the new pipeline.
     *
     * @throws IllegalStateException if a stage cannot be re-instantiated; silently skipping it
     *                               would shift stage indices and misapply parameters
     */
    private Pipeline buildPipelineCopy() {
        Pipeline copy = new Pipeline();
        for (int k = 0; k < this.pipeline.size(); k++) {
            try {
                copy.add((PipelineStageBase) this.pipeline.get(k).getClass()
                    .getConstructor(Params.class)
                    .newInstance(this.pipelineAllStageParams.get(k)));
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Failed to re-instantiate pipeline stage " + k + " ("
                    + this.pipeline.get(k).getClass().getName() + ") from its Params", e);
            }
        }
        return copy;
    }

    /**
     * Either verifies or records the parameter values for candidate {@code i}.
     *
     * <p>If a non-empty setting is already recorded at {@code i}, each recorded value must equal
     * the corresponding {@code list} value within 1e-9, or a {@link RuntimeException} is thrown.
     * Otherwise the values from {@code list} are stored at {@code i}, padding any gap with empty
     * placeholder arrays.
     */
    public void checkParamsSameValueOrNullValue(int i, List<Tuple3<Integer, ParamInfo, Object>> list) {
        if (this.items == null || this.items.isEmpty()) {
            setItems(this.paramDist);
        }
        boolean hasRecord = i < this.records.size() && this.records.get(i).length > 0;
        if (hasRecord) {
            Object[] recorded = this.records.get(i);
            for (int k = 0; k < recorded.length; k++) {
                double expected = ((Number) recorded[k]).doubleValue();
                double actual = ((Number) list.get(k).f2).doubleValue();
                if (Math.abs(expected - actual) > 1.0E-9d) {
                    throw new RuntimeException("checkParamsSameValueOrNullValue failed.");
                }
            }
            return;
        }
        Object[] values = new Object[list.size()];
        for (int k = 0; k < values.length; k++) {
            values[k] = list.get(k).f2;
        }
        while (this.records.size() <= i) {
            this.records.add(new Object[0]);
        }
        // NOTE(review): recordVectors is not updated here, unlike in get().
        this.records.set(i, values);
    }
}
