package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComQueue;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterTaskObjKeeper;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.XGboostException;
import com.alibaba.alink.common.io.plugin.TemporaryClassLoaderContext;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.outlier.OutlierUtil;
import com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp;
import com.alibaba.alink.operator.common.tree.xgboost.Booster;
import com.alibaba.alink.operator.common.tree.xgboost.Tracker;
import com.alibaba.alink.operator.common.tree.xgboost.XGBoost;
import com.alibaba.alink.operator.common.tree.xgboost.plugin.XGBoostClassLoaderFactory;
import com.alibaba.alink.params.xgboost.HasObjective;
import com.alibaba.alink.params.xgboost.XGBoostDebugParams;
import com.alibaba.alink.params.xgboost.XGBoostInputParams;
import com.alibaba.alink.params.xgboost.XGBoostLearningTaskParams;
import com.alibaba.alink.params.xgboost.XGBoostTrainParams;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.AggregateOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@FeatureColsVectorColMutexRule
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol")})
/* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseXGBoostTrainBatchOp.class */
public abstract class BaseXGBoostTrainBatchOp<T extends BaseXGBoostTrainBatchOp<T>> extends BatchOperator<T> {
    private static final Logger LOG = LoggerFactory.getLogger(BaseXGBoostTrainBatchOp.class);

    /**
     * ComQueue completion step for the {@code ICQ} running mode: hands back whatever list of model
     * rows the training step stored under {@link OptimVariable#model}.
     */
    public static class GenModel extends CompleteResultFunction {
        @Override
        @SuppressWarnings("unchecked")
        public List<Row> calc(ComContext comContext) {
            Object storedModel = comContext.getObj(OptimVariable.model);
            return (List<Row>) storedModel;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseXGBoostTrainBatchOp$InitTrackMapPartition.class */
    public static class InitTrackMapPartition extends RichMapPartitionFunction<Integer, Tuple2<String, String>> {
        private final XGBoostClassLoaderFactory xgBoostClassLoaderFactory;
        private final long trackerHandle;

        public InitTrackMapPartition(XGBoostClassLoaderFactory xGBoostClassLoaderFactory, long j) {
            this.xgBoostClassLoaderFactory = xGBoostClassLoaderFactory;
            this.trackerHandle = j;
        }

        public void mapPartition(Iterable<Integer> iterable, Collector<Tuple2<String, String>> collector) throws Exception {
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                Tracker initTracker = XGBoostClassLoaderFactory.create(this.xgBoostClassLoaderFactory).create().initTracker(getRuntimeContext().getNumberOfParallelSubtasks());
                if (!initTracker.start(0L)) {
                    throw new AkIllegalStateException("Tracker cannot be started");
                }
                IterTaskObjKeeper.put(this.trackerHandle, 0, initTracker);
                Iterator<Tuple2<String, String>> it = initTracker.getWorkerEnvs().iterator();
                while (it.hasNext()) {
                    collector.collect(it.next());
                }
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseXGBoostTrainBatchOp$InitTracker.class */
    public static class InitTracker extends ComputeFunction {
        private final XGBoostClassLoaderFactory xgBoostClassLoaderFactory;

        public InitTracker(XGBoostClassLoaderFactory xGBoostClassLoaderFactory) {
            this.xgBoostClassLoaderFactory = xGBoostClassLoaderFactory;
        }

        @Override // com.alibaba.alink.common.comqueue.ComputeFunction
        public void calc(ComContext comContext) {
            if (comContext.getTaskId() == 0) {
                try {
                    Tracker initTracker = XGBoostClassLoaderFactory.create(this.xgBoostClassLoaderFactory).create().initTracker(comContext.getNumTask());
                    if (!initTracker.start(0L)) {
                        throw new AkIllegalStateException("Tracker cannot be started");
                    }
                    comContext.putObj("tracker", initTracker);
                    comContext.putObj("workerEnvs", initTracker.getWorkerEnvs());
                } catch (XGboostException e) {
                    throw new AkIllegalStateException("XGboost error.", e);
                }
            }
        }
    }

    /**
     * ComQueue step for the {@code ICQ} running mode that stops, on task 0, the tracker that
     * {@link InitTracker} stored under {@code "tracker"}.
     */
    public static class RecycleTracker extends ComputeFunction {
        @Override
        public void calc(ComContext comContext) {
            boolean ownsTracker = comContext.getTaskId() == 0;
            if (ownsTracker) {
                Tracker tracker = (Tracker) comContext.getObj("tracker");
                tracker.stop();
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseXGBoostTrainBatchOp$RecycleTrackerMapPartition.class */
    public static class RecycleTrackerMapPartition extends RichMapPartitionFunction<Tuple2<Boolean, Row>, byte[]> {
        private final long trackerHandle;

        public RecycleTrackerMapPartition(long j) {
            this.trackerHandle = j;
        }

        public void mapPartition(Iterable<Tuple2<Boolean, Row>> iterable, Collector<byte[]> collector) throws Exception {
            iterable.iterator().next();
            Tracker tracker = (Tracker) IterTaskObjKeeper.get(this.trackerHandle, 0);
            if (tracker != null) {
                tracker.stop();
            }
            IterTaskObjKeeper.clear(this.trackerHandle);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseXGBoostTrainBatchOp$SelectSampleCol.class */
    public static class SelectSampleCol implements MapFunction<Row, Row> {
        private final int[] featureColIndices;
        private final int labelColIndex;

        public SelectSampleCol(int[] iArr, int i) {
            this.featureColIndices = iArr;
            this.labelColIndex = i;
        }

        public Row map(Row row) {
            return Row.of(new Object[]{OutlierUtil.rowToDenseVector(row, this.featureColIndices, this.featureColIndices.length), row.getField(this.labelColIndex)});
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseXGBoostTrainBatchOp$XGBoostTrain.class */
    public static class XGBoostTrain extends ComputeFunction {
        private final Params params;
        private final int vectorColIndex;
        private final int labelColIndex;
        private final HasObjective.Objective objective;
        private final XGBoostClassLoaderFactory xgBoostClassLoaderFactory;

        public XGBoostTrain(Params params, int i, int i2, XGBoostClassLoaderFactory xGBoostClassLoaderFactory) {
            this.params = params;
            this.vectorColIndex = i;
            this.labelColIndex = i2;
            this.objective = (HasObjective.Objective) params.get(XGBoostLearningTaskParams.OBJECTIVE);
            this.xgBoostClassLoaderFactory = xGBoostClassLoaderFactory;
        }

        @Override // com.alibaba.alink.common.comqueue.ComputeFunction
        public void calc(ComContext comContext) {
            int intValue = ((Integer) ((Tuple1) ((List) comContext.getObj(KMeansTrainBatchOp.VECTOR_SIZE)).get(0)).getField(0)).intValue();
            XGBoost create = XGBoostClassLoaderFactory.create(this.xgBoostClassLoaderFactory).create();
            ArrayList arrayList = new ArrayList((Collection) comContext.getObj("workerEnvs"));
            arrayList.add(Tuple2.of("DMLC_TASK_ID", String.valueOf(comContext.getTaskId())));
            try {
                create.init(arrayList);
                try {
                    try {
                        TemporaryClassLoaderContext of = TemporaryClassLoaderContext.of(this.xgBoostClassLoaderFactory.create());
                        Throwable th = null;
                        try {
                            try {
                                Booster train = BaseXGBoostTrainBatchOp.train(((Iterable) comContext.getObj("trainData")).iterator(), this.params, this.labelColIndex, this.vectorColIndex, intValue, create);
                                if (comContext.getTaskId() == 0) {
                                    comContext.putObj(OptimVariable.model, BaseXGBoostTrainBatchOp.generateModel(this.objective, this.params.m1495clone(), (List) comContext.getObj("labels"), intValue, train));
                                }
                                if (of != null) {
                                    if (0 != 0) {
                                        try {
                                            of.close();
                                        } catch (Throwable th2) {
                                            th.addSuppressed(th2);
                                        }
                                    } else {
                                        of.close();
                                    }
                                }
                                comContext.putObj("status", Collections.singletonList(true));
                            } finally {
                            }
                        } catch (Throwable th3) {
                            if (of != null) {
                                if (th != null) {
                                    try {
                                        of.close();
                                    } catch (Throwable th4) {
                                        th.addSuppressed(th4);
                                    }
                                } else {
                                    of.close();
                                }
                            }
                            throw th3;
                        }
                    } catch (XGboostException e) {
                        throw new AkIllegalStateException("XGBoost error.", e);
                    }
                } finally {
                    try {
                        create.shutdown();
                    } catch (XGboostException e2) {
                        BaseXGBoostTrainBatchOp.LOG.warn("Shutdown rabit error.", e2);
                    }
                }
            } catch (XGboostException e3) {
                throw new AkIllegalStateException("XGBoost init error.", e3);
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/BaseXGBoostTrainBatchOp$XGBoostTrainMapPartition.class */
    public static class XGBoostTrainMapPartition extends RichMapPartitionFunction<Row, Tuple2<Boolean, Row>> {
        private final Params params;
        private final int vectorColIndex;
        private final int labelColIndex;
        private final HasObjective.Objective objective;
        private final Row emptyModelRow;
        private final XGBoostClassLoaderFactory xgBoostClassLoaderFactory;
        private transient int vectorSize;
        private transient List<Tuple2<String, String>> workerEnvsList;
        private transient List<Object[]> labels;

        public XGBoostTrainMapPartition(Params params, int i, int i2, Row row, XGBoostClassLoaderFactory xGBoostClassLoaderFactory) {
            this.params = params;
            this.vectorColIndex = i;
            this.labelColIndex = i2;
            this.objective = (HasObjective.Objective) params.get(XGBoostLearningTaskParams.OBJECTIVE);
            this.emptyModelRow = row;
            this.xgBoostClassLoaderFactory = xGBoostClassLoaderFactory;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.vectorSize = ((Integer) getRuntimeContext().getBroadcastVariableWithInitializer(KMeansTrainBatchOp.VECTOR_SIZE, new BroadcastVariableInitializer<Tuple1<Integer>, Integer>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.XGBoostTrainMapPartition.1
                public Integer initializeBroadcastVariable(Iterable<Tuple1<Integer>> iterable) {
                    return (Integer) iterable.iterator().next().f0;
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m603initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Tuple1<Integer>>) iterable);
                }
            })).intValue();
            this.workerEnvsList = (List) getRuntimeContext().getBroadcastVariableWithInitializer("workerEnvs", new BroadcastVariableInitializer<Tuple2<String, String>, List<Tuple2<String, String>>>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.XGBoostTrainMapPartition.2
                public List<Tuple2<String, String>> initializeBroadcastVariable(Iterable<Tuple2<String, String>> iterable) {
                    ArrayList arrayList = new ArrayList();
                    Iterator<Tuple2<String, String>> it = iterable.iterator();
                    while (it.hasNext()) {
                        arrayList.add(it.next());
                    }
                    return arrayList;
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m604initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Tuple2<String, String>>) iterable);
                }
            });
            this.labels = (List) getRuntimeContext().getBroadcastVariableWithInitializer("labels", new BroadcastVariableInitializer<Object[], List<Object[]>>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.XGBoostTrainMapPartition.3
                public List<Object[]> initializeBroadcastVariable(Iterable<Object[]> iterable) {
                    ArrayList arrayList = new ArrayList();
                    Iterator<Object[]> it = iterable.iterator();
                    while (it.hasNext()) {
                        arrayList.add(it.next());
                    }
                    return arrayList;
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m605initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Object[]>) iterable);
                }
            });
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Boolean, Row>> collector) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            XGBoost create = XGBoostClassLoaderFactory.create(this.xgBoostClassLoaderFactory).create();
            ArrayList arrayList = new ArrayList(this.workerEnvsList);
            arrayList.add(Tuple2.of("DMLC_TASK_ID", String.valueOf(indexOfThisSubtask)));
            try {
                try {
                    create.init(arrayList);
                    try {
                        TemporaryClassLoaderContext of = TemporaryClassLoaderContext.of(this.xgBoostClassLoaderFactory.create());
                        Throwable th = null;
                        try {
                            try {
                                Booster train = BaseXGBoostTrainBatchOp.train(iterable.iterator(), this.params, this.labelColIndex, this.vectorColIndex, this.vectorSize, create);
                                if (indexOfThisSubtask == 0) {
                                    Iterator<Row> it = BaseXGBoostTrainBatchOp.generateModel(this.objective, this.params.m1495clone(), this.labels, this.vectorSize, train).iterator();
                                    while (it.hasNext()) {
                                        collector.collect(Tuple2.of(true, it.next()));
                                    }
                                }
                                if (of != null) {
                                    if (0 != 0) {
                                        try {
                                            of.close();
                                        } catch (Throwable th2) {
                                            th.addSuppressed(th2);
                                        }
                                    } else {
                                        of.close();
                                    }
                                }
                                collector.collect(Tuple2.of(false, this.emptyModelRow));
                            } finally {
                            }
                        } catch (Throwable th3) {
                            if (of != null) {
                                if (th != null) {
                                    try {
                                        of.close();
                                    } catch (Throwable th4) {
                                        th.addSuppressed(th4);
                                    }
                                } else {
                                    of.close();
                                }
                            }
                            throw th3;
                        }
                    } catch (XGboostException e) {
                        throw new AkIllegalStateException("XGBoost error", e);
                    }
                } catch (XGboostException e2) {
                    throw new AkIllegalStateException("XGBoost init error.", e2);
                }
            } finally {
                try {
                    create.shutdown();
                } catch (XGboostException e3) {
                    BaseXGBoostTrainBatchOp.LOG.warn("Shutdown rabit error.", e3);
                }
            }
        }
    }

    /** Creates the operator with a fresh, empty {@link Params}. */
    public BaseXGBoostTrainBatchOp() {
        super(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters, passed straight to {@link BatchOperator}
     */
    public BaseXGBoostTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the XGBoost training dataflow and sets the resulting model as this operator's output.
     *
     * <p>Steps:
     * <ol>
     *   <li>Read the objective.</li>
     *   <li>Generate the label list and cast the label column.</li>
     *   <li>Project every row to {@code (vector, label)}.</li>
     *   <li>Compute the maximum vector size.</li>
     *   <li>Train in the mode chosen by {@link XGBoostDebugParams#RUNNING_MODE}: {@code ICQ}, a
     *       ComQueue, or {@code TRIVIAL}, plain Flink operators.</li>
     * </ol>
     *
     * @param batchOperatorArr input operators; the training data is taken from
     *                         {@code checkAndGetFirst(batchOperatorArr)}
     * @return this operator
     * @throws AkIllegalArgumentException if the running mode is neither ICQ nor TRIVIAL
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public T linkFrom(BatchOperator<?>... batchOperatorArr) {
        DataSet<Row> name;
        BatchOperator<?> checkAndGetFirst = checkAndGetFirst(batchOperatorArr);
        Params m1495clone = getParams().m1495clone();
        HasObjective.Objective objective = (HasObjective.Objective) m1495clone.get(XGBoostLearningTaskParams.OBJECTIVE);
        XGBoostClassLoaderFactory xGBoostClassLoaderFactory = new XGBoostClassLoaderFactory((String) m1495clone.get(XGBoostTrainParams.PLUGIN_VERSION));
        // z: the objective is a binary/multi-class one (including BINARY_LOGITRAW). "!z" is passed to
        // the label preprocessing below — presumably an "is regression" flag; confirm in Preprocessing.
        boolean z = objective.equals(HasObjective.Objective.BINARY_LOGISTIC) || objective.equals(HasObjective.Objective.BINARY_LOGITRAW) || objective.equals(HasObjective.Objective.BINARY_HINGE) || objective.equals(HasObjective.Objective.MULTI_SOFTMAX) || objective.equals(HasObjective.Objective.MULTI_SOFTPROB);
        int findColIndex = TableUtil.findColIndex(checkAndGetFirst.getColNames(), (String) m1495clone.get(XGBoostInputParams.LABEL_COL));
        // Label type stored in the model schema. It is the original label column type for the
        // objectives whose model also stores the label list (same set as in generateModel), and
        // DOUBLE otherwise. BINARY_LOGITRAW counts as classification in z but keeps DOUBLE here.
        TypeInformation<?> typeInformation = Types.DOUBLE;
        if (objective.equals(HasObjective.Objective.BINARY_LOGISTIC) || objective.equals(HasObjective.Objective.BINARY_HINGE) || objective.equals(HasObjective.Objective.MULTI_SOFTMAX) || objective.equals(HasObjective.Objective.MULTI_SOFTPROB)) {
            typeInformation = TableUtil.findColType(checkAndGetFirst.getSchema(), (String) m1495clone.get(XGBoostInputParams.LABEL_COL));
        }
        DataSet<Object[]> generateLabels = Preprocessing.generateLabels(checkAndGetFirst, m1495clone, !z);
        BatchOperator<?> castLabel = Preprocessing.castLabel(checkAndGetFirst, m1495clone, generateLabels, !z);
        // Project every row to (vector, label). Field 0 is the feature vector and field 1 the label,
        // matching the vectorColIndex=0 / labelColIndex=1 passed to the train functions below. With
        // feature columns, SelectSampleCol packs them into a dense vector.
        // NOTE(review): findColIndex was resolved against the input schema but is used to index rows
        // of castLabel; this assumes castLabel preserves column order — confirm.
        DataSet<Row> dataSet = m1495clone.contains(XGBoostInputParams.VECTOR_COL) ? Preprocessing.select(castLabel, (String) m1495clone.get(XGBoostInputParams.VECTOR_COL), (String) m1495clone.get(XGBoostInputParams.LABEL_COL)).getDataSet() : castLabel.getDataSet().map(new SelectSampleCol(TableUtil.findColIndicesWithAssertAndHint(castLabel.getColNames(), OutlierUtil.uniformFeatureColsDefaultAsAll(TableUtil.getNumericCols(castLabel.getSchema()), (String[]) m1495clone.get(XGBoostInputParams.FEATURE_COLS))), findColIndex)).returns(new RowTypeInfo(new TypeInformation[]{AlinkTypes.DENSE_VECTOR, TableUtil.findColType(castLabel.getSchema(), (String) m1495clone.get(XGBoostInputParams.LABEL_COL))}));
        // Global maximum vector size. A sparse vector with negative size (unknown dimension)
        // contributes last index + 1, which assumes its indices are sorted — confirm in SparseVector.
        AggregateOperator max = dataSet.map(new MapFunction<Row, Tuple1<Integer>>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.1
            public Tuple1<Integer> map(Row row) throws Exception {
                Vector vector = VectorUtil.getVector(row.getField(0));
                if ((vector instanceof SparseVector) && vector.size() < 0) {
                    int[] indices = ((SparseVector) vector).getIndices();
                    return Tuple1.of(Integer.valueOf((indices == null || indices.length == 0) ? 0 : indices[indices.length - 1] + 1));
                }
                return Tuple1.of(Integer.valueOf(vector.size()));
            }
        }).name("Extract vector size").max(0);
        XGBoostDebugParams.RunningMode runningMode = (XGBoostDebugParams.RunningMode) m1495clone.get(XGBoostDebugParams.RUNNING_MODE);
        TableSchema modelSchema = new XGBoostModelDataConverter(typeInformation).getModelSchema();
        switch (runningMode) {
            case ICQ:
                // ICQ runs one ComQueue pass:
                //   1. InitTracker starts the tracker on task 0.
                //   2. The worker envs are broadcast.
                //   3. Every task trains, and task 0 stores the model.
                //   4. The statuses are gathered.
                //   5. RecycleTracker stops the tracker on task 0.
                //   6. GenModel returns the stored model rows.
                name = new ComQueue().initWithPartitionedData("trainData", dataSet).initWithBroadcastData("labels", generateLabels).initWithBroadcastData(KMeansTrainBatchOp.VECTOR_SIZE, max).add(new InitTracker(xGBoostClassLoaderFactory)).add(new Bcast("workerEnvs", 0, Types.TUPLE(new TypeInformation[]{Types.STRING, Types.STRING}))).add(new XGBoostTrain(m1495clone, 0, 1, xGBoostClassLoaderFactory)).add(new Gather("status", 0, Types.BOOLEAN)).add(new RecycleTracker()).closeWith(new GenModel()).exec();
                break;
            case TRIVIAL:
                // TRIVIAL uses plain Flink operators.
                // Setup: a one-element dataset is hash-partitioned into InitTrackMapPartition, where
                // only subtask 0 starts the tracker and emits the worker envs. Those envs are
                // broadcast to the training mapPartition as "workerEnvs".
                // Output: training emits (true, modelRow) from subtask 0, plus one
                // (false, emptyModelRow) per successful subtask.
                long newHandle = IterTaskObjKeeper.getNewHandle();
                Operator name2 = dataSet.mapPartition(new XGBoostTrainMapPartition(m1495clone, 0, 1, (Row) new RowTypeInfo(modelSchema.getFieldTypes()).createSerializer(new ExecutionConfig()).createInstance(), xGBoostClassLoaderFactory)).withBroadcastSet(generateLabels, "labels").withBroadcastSet(max, KMeansTrainBatchOp.VECTOR_SIZE).withBroadcastSet(MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Integer[]{0}).partitionByHash(new KeySelector<Integer, Integer>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.2
                    public Integer getKey(Integer num) {
                        return num;
                    }
                }).mapPartition(new InitTrackMapPartition(xGBoostClassLoaderFactory, newHandle)).name("Init tracker"), "workerEnvs").name("XGBoost train");
                // Keep only the model rows (f0 == true). The (false, ...) markers go to
                // RecycleTrackerMapPartition, which stops the tracker and emits nothing. Its output is
                // attached as broadcast "recycleResult", presumably so the recycle branch becomes part
                // of the executed plan.
                name = name2.filter(new FilterFunction<Tuple2<Boolean, Row>>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.5
                    public boolean filter(Tuple2<Boolean, Row> tuple2) throws Exception {
                        return ((Boolean) tuple2.f0).booleanValue();
                    }
                }).map(new MapFunction<Tuple2<Boolean, Row>, Row>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.4
                    public Row map(Tuple2<Boolean, Row> tuple2) throws Exception {
                        return (Row) tuple2.f1;
                    }
                }).withBroadcastSet(name2.filter(new FilterFunction<Tuple2<Boolean, Row>>() { // from class: com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp.3
                    public boolean filter(Tuple2<Boolean, Row> tuple2) throws Exception {
                        return !((Boolean) tuple2.f0).booleanValue();
                    }
                }).mapPartition(new RecycleTrackerMapPartition(newHandle)).name("Recycle tracker"), "recycleResult").name("Gen model");
                break;
            default:
                throw new AkIllegalArgumentException("Illegal running mode: " + runningMode);
        }
        setOutput(name, modelSchema);
        return this;
    }

    /**
     * Trains a booster on the given rows through the supplied {@link XGBoost} instance.
     *
     * <p>Each row is turned into a {@code (features, {label})} pair. A {@link SparseVector} whose
     * size is negative is resized in place to {@code vectorSize} first. The label must be a
     * non-null {@link Number}; it is converted to {@code float}.
     *
     * @param it             rows to train on
     * @param params         XGBoost training parameters
     * @param labelColIndex  field index of the label in each row
     * @param vectorColIndex field index of the feature vector in each row
     * @param vectorSize     size assigned to sparse vectors whose size is negative
     * @param xgBoost        initialized XGBoost entry point
     * @return the trained booster
     * @throws XGboostException propagated from {@code XGBoost#train}
     */
    public static Booster train(Iterator<Row> it, Params params, int labelColIndex, int vectorColIndex,
                                int vectorSize, XGBoost xgBoost) throws XGboostException {
        return xgBoost.train(
            it,
            row -> row,
            row -> {
                Vector features = VectorUtil.getVector(row.getField(vectorColIndex));
                boolean unknownSparseSize = features instanceof SparseVector && features.size() < 0;
                if (unknownSparseSize) {
                    ((SparseVector) features).setSize(vectorSize);
                }
                Object rawLabel = AkPreconditions.checkNotNull(row.getField(labelColIndex));
                float label = ((Number) rawLabel).floatValue();
                return Tuple2.of(features, new float[] {label});
            },
            params);
    }

    /**
     * Serializes a trained booster into model rows through {@link XGBoostModelDataConverter}.
     *
     * <p>The booster bytes are split into chunks of at most 1024 bytes, and each chunk is
     * Base64-encoded. The vector size is written into the meta under
     * {@link XGBoostModelDataConverter#XGBOOST_VECTOR_SIZE}. For BINARY_LOGISTIC, BINARY_HINGE,
     * MULTI_SOFTMAX and MULTI_SOFTPROB, element 0 of {@code labels} is stored as the model labels.
     *
     * @param objective  learning objective; decides whether labels are stored
     * @param params     meta parameters; {@code set} is called on this instance, so callers pass a copy
     * @param labels     broadcast label list; only element 0 is read, and only for the objectives above
     * @param vectorSize feature vector size recorded in the meta
     * @param booster    trained booster to serialize
     * @return the model rows emitted by the converter
     * @throws XGboostException propagated from {@code booster.toByteArray()}
     */
    public static List<Row> generateModel(HasObjective.Objective objective, Params params, List<Object[]> labels,
                                          int vectorSize, Booster booster) throws XGboostException {
        final int chunkSize = 1024;
        final byte[] modelBytes = booster.toByteArray();
        // Ceiling division written without "+ chunkSize - 1" so it cannot overflow.
        final int numChunks = modelBytes.length / chunkSize + (modelBytes.length % chunkSize == 0 ? 0 : 1);
        final Base64.Encoder encoder = Base64.getEncoder();

        XGBoostModelDataConverter converter = new XGBoostModelDataConverter();
        // Fixed: the decompiled code cast Integer to ParamInfo<Integer> here, which does not compile.
        converter.meta = params.set(XGBoostModelDataConverter.XGBOOST_VECTOR_SIZE, vectorSize);
        converter.modelData = () -> new Iterator<String>() {
            private int counter = 0;

            @Override
            public boolean hasNext() {
                return counter < numChunks;
            }

            @Override
            public String next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int from = counter * chunkSize;
                // from < modelBytes.length here, so this end index cannot overflow.
                int to = from + Math.min(chunkSize, modelBytes.length - from);
                counter++;
                return encoder.encodeToString(ArrayUtils.subarray(modelBytes, from, to));
            }
        };
        if (objective.equals(HasObjective.Objective.BINARY_LOGISTIC)
            || objective.equals(HasObjective.Objective.BINARY_HINGE)
            || objective.equals(HasObjective.Objective.MULTI_SOFTMAX)
            || objective.equals(HasObjective.Objective.MULTI_SOFTPROB)) {
            converter.labels = labels.get(0);
        }
        final List<Row> rows = new ArrayList<>();
        converter.save(converter, new Collector<Row>() {
            @Override
            public void collect(Row row) {
                rows.add(row);
            }

            @Override
            public void close() {
            }
        });
        return rows;
    }

    // javac generates the erased bridge method for linkFrom(BatchOperator<?>...) automatically.
    // Declaring linkFrom(BatchOperator[]) by hand has the same erasure and is a compile-time error.
}
