package com.alibaba.alink.operator.batch.outlier;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.dataproc.NumericalTypeCastMapper;
import com.alibaba.alink.operator.common.outlier.IForestDetector;
import com.alibaba.alink.operator.common.outlier.IForestModelDetector;
import com.alibaba.alink.operator.common.outlier.OutlierUtil;
import com.alibaba.alink.params.dataproc.HasTargetType;
import com.alibaba.alink.params.dataproc.NumericalTypeCastParams;
import com.alibaba.alink.params.outlier.IForestTrainParams;
import com.alibaba.alink.params.outlier.WithMultiVarParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch operator that trains an IForest (isolation forest) outlier-detection model.
 *
 * <p>Training is expressed as a Flink bulk iteration. In every superstep each parallel subtask draws a
 * random sample of at most {@code subsamplingSize} rows per tree, the samples are routed by tree index,
 * and every receiving subtask builds one tree from the rows it received. In the final superstep the
 * accumulated trees and the (adjusted) training parameters are emitted and assembled into the model rows.
 *
 * <p>NOTE(review): this source is decompiler output (JADX); the raw types and unusual casts are artifacts
 * of decompilation, not of the original authoring.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@NameCn("IForest模型异常检测训练")
@NameEn("IForest model outlier")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.outlier.IForestModelOutlier")
/* loaded from: input_file:com/alibaba/alink/operator/batch/outlier/IForestModelOutlierTrainBatchOp.class */
public class IForestModelOutlierTrainBatchOp extends BatchOperator<IForestModelOutlierTrainBatchOp> implements IForestTrainParams<IForestModelOutlierTrainBatchOp> {
    // Class logger; not referenced anywhere in the code visible in this file.
    private static final Logger LOG = LoggerFactory.getLogger(IForestModelOutlierTrainBatchOp.class);
    // Natural logarithm of 2; divides Math.log(...) to obtain a base-2 logarithm for the tree height limit.
    private static final double LOG2 = Math.log(2.0d);

    /** Creates the operator with no explicit parameters (delegates with {@code null}). */
    public IForestModelOutlierTrainBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters, passed unchanged to the superclass constructor
     */
    public IForestModelOutlierTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the training data flow on the first input operator and registers the resulting model
     * data set as this operator's output.
     *
     * <p>Records exchanged inside the iteration are {@code Tuple2<Integer, byte[]>} with a tag in
     * {@code f0}: {@code -1} = placeholder used only for the termination criterion, {@code 0} =
     * serialized {@link Params} (model meta), {@code 1} = serialized array of tree nodes.
     *
     * @param batchOperatorArr input operators; only the first one (via {@code checkAndGetFirst}) is used
     * @return this operator
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public IForestModelOutlierTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        // Data set carrying the maximum vector size (vector input) or a single 0 (column input);
        // it is broadcast as "maxVectorSize" to the tree-building function below.
        ReduceOperator fromElements;
        BatchOperator<?> checkAndGetFirst = checkAndGetFirst(batchOperatorArr);
        // Work on a copy of the parameters; the copy is captured by the anonymous functions below.
        final Params m1495clone = getParams().m1495clone();
        // intValue = number of trees to build, intValue2 = sub-sampling size per tree.
        final int intValue = getNumTrees().intValue();
        final int intValue2 = getSubsamplingSize().intValue();
        final String[] colNames = checkAndGetFirst.getColNames();
        final TypeInformation<?>[] colTypes = checkAndGetFirst.getColTypes();
        DataSet<Row> dataSet = checkAndGetFirst.getDataSet();
        if (m1495clone.contains(WithMultiVarParams.VECTOR_COL)) {
            // Vector input: extract the vector column, map each vector to OutlierUtil.vectorSize(...),
            // and reduce with max to get the largest size over the whole input.
            final int findColIndexWithAssertAndHint = TableUtil.findColIndexWithAssertAndHint(checkAndGetFirst.getSchema(), (String) m1495clone.get(WithMultiVarParams.VECTOR_COL));
            fromElements = dataSet.map(new MapFunction<Row, Vector>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.3
                public Vector map(Row row) throws Exception {
                    return (Vector) row.getField(findColIndexWithAssertAndHint);
                }
            }).map(new MapFunction<Vector, Integer>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.2
                public Integer map(Vector vector) throws Exception {
                    return Integer.valueOf(OutlierUtil.vectorSize(vector));
                }
            }).reduce(new ReduceFunction<Integer>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.1
                public Integer reduce(Integer num, Integer num2) throws Exception {
                    return Integer.valueOf(Math.max(num.intValue(), num2.intValue()));
                }
            });
        } else {
            // Column input: the broadcast value is a dummy 0 (it is only read when VECTOR_COL is set),
            // and FEATURE_COLS is set to whatever OutlierUtil.fillFeatures returns for the input columns.
            fromElements = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Integer[]{0});
            m1495clone.set((ParamInfo<ParamInfo<String[]>>) WithMultiVarParams.FEATURE_COLS, (ParamInfo<String[]>) OutlierUtil.fillFeatures(checkAndGetFirst.getColNames(), m1495clone));
        }
        // Bulk iteration seeded with a single placeholder record; numTrees is the upper bound on supersteps.
        IterativeDataSet iterate = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Tuple2[]{Tuple2.of(-1, new byte[0])}).iterate(intValue);
        // Stage 1 (per input partition): draw, for each tree of this superstep, a random sample of rows.
        SingleInputUdfOperator withBroadcastSet = dataSet.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.6
            public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Row>> collector) {
                int numberOfParallelSubtasks = getIterationRuntimeContext().getNumberOfParallelSubtasks();
                int superstepNumber = getIterationRuntimeContext().getSuperstepNumber();
                // i = number of trees scheduled in earlier supersteps (one per subtask per superstep).
                int i = (superstepNumber - 1) * numberOfParallelSubtasks;
                // min = trees to sample for now: the remaining trees, capped at the parallelism.
                int min = Math.min(intValue - i, (superstepNumber * numberOfParallelSubtasks) - i);
                ThreadLocalRandom current = ThreadLocalRandom.current();
                // One min-heap per tree, ordered by the random key in f0 (smallest key at the head).
                ArrayList arrayList = new ArrayList(min);
                for (int i2 = 0; i2 < min; i2++) {
                    arrayList.add(new PriorityQueue(intValue2, Comparator.comparing(tuple3 -> {
                        return (Double) tuple3.f0;
                    })));
                }
                // Every row gets an independent random key per tree; each heap retains the (at most)
                // subsamplingSize rows with the largest keys, i.e. a uniform random sample of this partition.
                for (Row row : iterable) {
                    for (int i3 = 0; i3 < min; i3++) {
                        PriorityQueue priorityQueue = (PriorityQueue) arrayList.get(i3);
                        if (priorityQueue.size() < intValue2) {
                            priorityQueue.offer(Tuple3.of(Double.valueOf(current.nextDouble()), Integer.valueOf(i3), row));
                        } else {
                            Double valueOf = Double.valueOf(current.nextDouble());
                            // Replace the current smallest key only when the new key is larger.
                            if (valueOf.doubleValue() > ((Double) ((Tuple3) priorityQueue.element()).f0).doubleValue()) {
                                priorityQueue.poll();
                                priorityQueue.offer(Tuple3.of(valueOf, Integer.valueOf(i3), row));
                            }
                        }
                    }
                }
                // Emit (treeIndex, row) for every retained row; the random key is dropped.
                Iterator it = arrayList.iterator();
                while (it.hasNext()) {
                    Iterator it2 = ((PriorityQueue) it.next()).iterator();
                    while (it2.hasNext()) {
                        Tuple3 tuple32 = (Tuple3) it2.next();
                        collector.collect(Tuple2.of(tuple32.f1, tuple32.f2));
                    }
                }
            }
        // NOTE(review): the "loop" broadcast variable is never read by the function above; presumably it is
        // attached only so that this operator becomes part of the iteration body (it needs the iteration
        // runtime context and must be re-run each superstep) - confirm.
        // The custom partitioner uses the tree index itself as the target partition; indices produced above
        // are always < parallelism of the sampling stage.
        }).withBroadcastSet(iterate, "loop").partitionCustom(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.5
            public int partition(Integer num, int i) {
                return num.intValue();
            }
        // Stage 2 (per tree partition): merge the per-partition samples, build one tree, and on the final
        // superstep emit the model pieces.
        }, 0).mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Integer, byte[]>>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.4
            // Trees built by this function instance so far; everything in here is emitted on the final superstep.
            // NOTE(review): this relies on the same function instance being reused across supersteps - confirm
            // against Flink's iteration runtime behavior.
            private final List<List<IForestDetector.Node>> model = new ArrayList();

            /* JADX WARN: Type inference failed for: r2v14, types: [java.lang.Object[], java.io.Serializable] */
            public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Integer, byte[]>> collector) throws Exception {
                ThreadLocalRandom current = ThreadLocalRandom.current();
                // Second sampling pass with fresh random keys: reduce the union of the incoming per-partition
                // samples to at most subsamplingSize rows (same min-heap scheme as in stage 1).
                PriorityQueue priorityQueue = new PriorityQueue(intValue2, Comparator.comparing(tuple2 -> {
                    return (Double) tuple2.f0;
                }));
                for (Tuple2<Integer, Row> tuple22 : iterable) {
                    if (priorityQueue.size() < intValue2) {
                        priorityQueue.offer(Tuple2.of(Double.valueOf(current.nextDouble()), tuple22.f1));
                    } else {
                        Double valueOf = Double.valueOf(current.nextDouble());
                        if (valueOf.doubleValue() > ((Double) ((Tuple2) priorityQueue.element()).f0).doubleValue()) {
                            priorityQueue.poll();
                            priorityQueue.offer(Tuple2.of(valueOf, tuple22.f1));
                        }
                    }
                }
                // Wrap the sampled rows in an MTable using the input schema and let OutlierUtil.getMTable
                // derive the training table from the parameters (presumably selecting the feature columns or
                // expanding the vector column - confirm in OutlierUtil).
                MTable mTable = OutlierUtil.getMTable(new MTable((List<Row>) priorityQueue.stream().map(tuple23 -> {
                    return (Row) tuple23.f1;
                }).collect(Collectors.toList()), colNames, (TypeInformation<?>[]) colTypes), m1495clone);
                // Cast every column of the training table to DOUBLE.
                NumericalTypeCastMapper numericalTypeCastMapper = new NumericalTypeCastMapper(mTable.getSchema(), new Params().set((ParamInfo<ParamInfo<String[]>>) NumericalTypeCastParams.SELECTED_COLS, (ParamInfo<String[]>) mTable.getColNames()).set((ParamInfo<ParamInfo<HasTargetType.TargetType>>) NumericalTypeCastParams.TARGET_TYPE, (ParamInfo<HasTargetType.TargetType>) HasTargetType.TargetType.DOUBLE));
                int numRow = mTable.getNumRow();
                ArrayList arrayList = new ArrayList(numRow);
                for (int i = 0; i < numRow; i++) {
                    arrayList.add(numericalTypeCastMapper.map(mTable.getRow(i)));
                }
                MTable mTable2 = new MTable(arrayList, mTable.getSchemaStr());
                // Build a tree only when this subtask received rows; the second iTree argument is
                // ceil(log2(min(sampleSize, subsamplingSize))) (presumably the tree height limit - confirm
                // in IForestDetector.IForestTrain).
                if (numRow > 0) {
                    this.model.add(new IForestDetector.IForestTrain(m1495clone).iTree(mTable2, (int) Math.ceil(Math.log(Math.min(arrayList.size(), intValue2)) / IForestModelOutlierTrainBatchOp.LOG2), current));
                }
                int numberOfParallelSubtasks = getIterationRuntimeContext().getNumberOfParallelSubtasks();
                int superstepNumber = getIterationRuntimeContext().getSuperstepNumber();
                // Not the final superstep (superstep * parallelism is still below numTrees) or already past
                // the requested tree count: emit only the placeholder, which keeps the termination-criterion
                // data set non-empty so that the iteration continues.
                if ((superstepNumber - 1) * numberOfParallelSubtasks >= intValue || superstepNumber * numberOfParallelSubtasks < intValue) {
                    collector.collect(Tuple2.of(-1, new byte[0]));
                    return;
                }
                // Final superstep: subtask 0 alone emits the model meta (tag 0).
                if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                    // Clamp the stored sub-sampling size to the number of rows in subtask 0's sample.
                    // NOTE(review): numRow is the sample size of this subtask in this superstep only - confirm
                    // that is the intended value for the model meta.
                    m1495clone.set((ParamInfo<ParamInfo<Integer>>) IForestTrainParams.SUBSAMPLING_SIZE, (ParamInfo<Integer>) Integer.valueOf(Math.min(((Integer) m1495clone.get(IForestTrainParams.SUBSAMPLING_SIZE)).intValue(), numRow)));
                    if (m1495clone.contains(WithMultiVarParams.VECTOR_COL)) {
                        // Record the maximum vector size (first element of the "maxVectorSize" broadcast set).
                        m1495clone.set((ParamInfo<ParamInfo<Integer>>) OutlierUtil.MAX_VECTOR_SIZE, (ParamInfo<Integer>) getRuntimeContext().getBroadcastVariableWithInitializer("maxVectorSize", new BroadcastVariableInitializer<Integer, Integer>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.4.1
                            public Integer initializeBroadcastVariable(Iterable<Integer> iterable2) {
                                return iterable2.iterator().next();
                            }

                            /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                            public /* bridge */ /* synthetic */ Object m282initializeBroadcastVariable(Iterable iterable2) {
                                return initializeBroadcastVariable((Iterable<Integer>) iterable2);
                            }
                        }));
                    }
                    collector.collect(Tuple2.of(0, SerializationUtils.serialize(m1495clone)));
                }
                // Every subtask emits each tree it has accumulated as a serialized Node[] (tag 1).
                Iterator<List<IForestDetector.Node>> it = this.model.iterator();
                while (it.hasNext()) {
                    collector.collect(Tuple2.of(1, SerializationUtils.serialize((Serializable) it.next().toArray(new IForestDetector.Node[0]))));
                }
            }
        }).withBroadcastSet(fromElements, "maxVectorSize");
        // Close the iteration: records with tag >= 0 are the iteration result / feedback, records with
        // tag < 0 form the termination criterion (per Flink's closeWith contract the iteration stops once
        // that data set is empty, or when the maximum number of supersteps is reached).
        // The final reduceGroup rebuilds an IForestModel from the tagged records and writes it out as rows.
        setOutput((DataSet<Row>) iterate.closeWith(withBroadcastSet.filter(new FilterFunction<Tuple2<Integer, byte[]>>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.7
            public boolean filter(Tuple2<Integer, byte[]> tuple2) {
                return ((Integer) tuple2.f0).intValue() >= 0;
            }
        }), withBroadcastSet.filter(new FilterFunction<Tuple2<Integer, byte[]>>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.8
            public boolean filter(Tuple2<Integer, byte[]> tuple2) {
                return ((Integer) tuple2.f0).intValue() < 0;
            }
        })).reduceGroup(new GroupReduceFunction<Tuple2<Integer, byte[]>, Row>() { // from class: com.alibaba.alink.operator.batch.outlier.IForestModelOutlierTrainBatchOp.9
            public void reduce(Iterable<Tuple2<Integer, byte[]>> iterable, Collector<Row> collector) {
                IForestModelDetector.IForestModel iForestModel = new IForestModelDetector.IForestModel();
                for (Tuple2<Integer, byte[]> tuple2 : iterable) {
                    if (((Integer) tuple2.f0).intValue() == 0) {
                        // Tag 0: deserialized Params become the model meta.
                        iForestModel.meta = (Params) SerializationUtils.deserialize((byte[]) tuple2.f1);
                    } else {
                        // Any other tag: a deserialized node array, added as one tree.
                        iForestModel.trees.add(new ArrayList(Arrays.asList((Object[]) SerializationUtils.deserialize((byte[]) tuple2.f1))));
                    }
                }
                new IForestModelDetector.IForestModelDataConverter().save(iForestModel, collector);
            }
        }), new IForestModelDetector.IForestModelDataConverter().getModelSchema());
        return this;
    }

    // Synthetic bridge overload emitted by the decompiler; it only delegates to the typed linkFrom above.
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ IForestModelOutlierTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
