package com.alibaba.alink.operator.batch.outlier;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.outlier.OcsvmKernel;
import com.alibaba.alink.operator.common.outlier.OcsvmModelData;
import com.alibaba.alink.operator.common.outlier.OcsvmModelDataConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.outlier.HaskernelType;
import com.alibaba.alink.params.outlier.OcsvmModelTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@NameCn("One Class SVM异常检测模型训练")
@NameEn("Ocsvm outlier model train")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.outlier.OcsvmModelOutlier")
/* loaded from: input_file:com/alibaba/alink/operator/batch/outlier/OcsvmModelOutlierTrainBatchOp.class */
public final class OcsvmModelOutlierTrainBatchOp extends BatchOperator<OcsvmModelOutlierTrainBatchOp> implements OcsvmModelTrainParams<OcsvmModelOutlierTrainBatchOp> {
    private static final long serialVersionUID = 6727016080849088600L;

    /**
     * Per-partition statistics collector.
     *
     * <p>For every partition it emits exactly one tuple of
     * (number of rows seen, feature dimension, parallelism of this operator).
     * The feature dimension stays {@code -1} when the partition is empty.
     */
    public static class CalculateNum extends RichMapPartitionFunction<Row, Tuple3<Integer, Integer, Integer>> {
        private static final long serialVersionUID = -679835005763383100L;
        private final String[] featureColnames;
        private final String tensorColName;

        /**
         * @param strArr names of the numeric feature columns; only read when {@code str} is null
         * @param str    name of the vector column, or null when plain feature columns are used
         */
        public CalculateNum(String[] strArr, String str) {
            this.featureColnames = strArr;
            this.tensorColName = str;
        }

        /**
         * Counts the rows of the partition and derives the feature dimension.
         *
         * <p>With a vector column, field 0 of each row is parsed as a vector: a sparse vector
         * contributes (last stored index + 1, or 0 if it has no entries) to a running maximum,
         * while any other vector overwrites the dimension with its {@code size()}.
         * Without a vector column the dimension is the number of feature columns.
         */
        public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Integer, Integer, Integer>> collector) throws Exception {
            final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
            final boolean vectorInput = this.tensorColName != null;
            int rowCount = 0;
            int dimension = -1;

            Iterator<Row> rows = iterable.iterator();
            while (rows.hasNext()) {
                Row current = rows.next();
                rowCount += 1;

                if (!vectorInput) {
                    dimension = this.featureColnames.length;
                    continue;
                }

                Vector parsed = VectorUtil.getVector(current.getField(0));
                if (!(parsed instanceof SparseVector)) {
                    dimension = parsed.size();
                    continue;
                }

                // Relies on the last stored index being the largest one.
                int[] storedIndices = ((SparseVector) parsed).getIndices();
                int sparseDimension = storedIndices.length == 0 ? 0 : storedIndices[storedIndices.length - 1] + 1;
                dimension = Math.max(dimension, sparseDimension);
            }

            collector.collect(Tuple3.of(rowCount, dimension, parallelism));
        }
    }

    /**
     * Projects an input row onto a fixed list of column positions, converting every
     * selected field (which must be a {@link Number}) to a {@code Double}.
     */
    public static class SelectFeat implements MapFunction<Row, Row> {
        private static final long serialVersionUID = 331016784088329722L;
        private final int[] featureIndices;

        /**
         * @param iArr positions of the feature columns inside the incoming rows
         */
        public SelectFeat(int[] iArr) {
            this.featureIndices = iArr;
        }

        /**
         * Builds a new row whose k-th field is the double value of the input field at
         * {@code featureIndices[k]}.
         */
        public Row map(Row row) throws Exception {
            final int width = this.featureIndices.length;
            Row projected = new Row(width);
            int target = 0;
            for (int sourceIndex : this.featureIndices) {
                Number raw = (Number) row.getField(sourceIndex);
                projected.setField(target, raw.doubleValue());
                target++;
            }
            return projected;
        }
    }

    /**
     * Trains one SVM model per group (bag) of samples.
     *
     * <p>Every incoming element carries its bag id in {@code f0} and the sample row in
     * {@code f1}. For a non-empty group the function emits one tuple of
     * ({@code 1 / featureDimension}, trained model).
     *
     * <p>The feature dimension is derived for both input layouts. With a vector column
     * it is computed from the vectors themselves, so the automatic gamma is also
     * applied there instead of dividing by a dimension that was never set.
     */
    public static class TrainSvm implements GroupReduceFunction<Tuple2<Integer, Row>, Tuple2<Double, OcsvmModelData.SvmModelData>> {
        private static final long serialVersionUID = -2783415250850319839L;
        private final Params param;
        private final String tensorColName;

        /**
         * @param params training parameters; {@code VECTOR_COL} decides how rows are decoded
         *               and {@code GAMMA} may be filled in by {@link #reduce}
         */
        TrainSvm(Params params) {
            this.param = params;
            this.tensorColName = (String) params.get(OcsvmModelTrainParams.VECTOR_COL);
        }

        /**
         * Collects the samples of one group, fixes up gamma if it is (numerically) zero,
         * and trains the model.
         *
         * @param iterable  (bag id, sample row) pairs belonging to one group
         * @param collector receives a single (gamma, model) tuple, or nothing for an empty group
         */
        public void reduce(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Double, OcsvmModelData.SvmModelData>> collector) throws Exception {
            final boolean vectorInput = this.tensorColName != null;
            ArrayList<Vector> samples = new ArrayList<>();
            int featureDim = 0;

            for (Tuple2<Integer, Row> element : iterable) {
                Row row = element.f1;
                Vector sample = vectorInput ? VectorUtil.getVector(row.getField(0)) : toDenseVector(row);
                featureDim = Math.max(featureDim, dimensionOf(sample));
                samples.add(sample);
            }
            if (samples.isEmpty()) {
                return;
            }

            // NOTE(review): if no sample has any entry, featureDim is 0 and this is still
            // Infinity, as in the previous implementation - confirm the desired fallback.
            double gammaFromDimension = 1.0d / featureDim;

            // A gamma of (almost) zero means "not configured": fall back to 1 / #features.
            // The shared Params instance is updated in place, so the value set for the
            // first group is also seen by later groups handled by this function instance.
            if (Math.abs(((Double) this.param.get(OcsvmModelTrainParams.GAMMA)).doubleValue()) < 1.0E-18d && featureDim > 0) {
                this.param.set(OcsvmModelTrainParams.GAMMA, Double.valueOf(gammaFromDimension));
            }

            // NOTE(review): the emitted gamma is always 1 / featureDim, even when the user
            // configured a non-zero gamma that is used for training - verify against the
            // consumer of this value.
            Vector[] sampleArray = samples.toArray(new Vector[0]);
            collector.collect(Tuple2.of(Double.valueOf(gammaFromDimension), OcsvmKernel.svmTrain(sampleArray, this.param)));
        }

        /** Copies all (numeric) fields of a row into a dense vector of the row's arity. */
        private static DenseVector toDenseVector(Row row) {
            int arity = row.getArity();
            DenseVector dense = new DenseVector(arity);
            for (int k = 0; k < arity; k++) {
                dense.set(k, ((Number) row.getField(k)).doubleValue());
            }
            return dense;
        }

        /**
         * Feature dimension of a single sample: last stored index + 1 for a sparse vector
         * (0 when it has no entries), {@code size()} otherwise. Assumes the last stored
         * index of a sparse vector is its largest one.
         */
        private static int dimensionOf(Vector sample) {
            if (sample instanceof SparseVector) {
                int[] storedIndices = ((SparseVector) sample).getIndices();
                return storedIndices.length == 0 ? 0 : storedIndices[storedIndices.length - 1] + 1;
            }
            return sample.size();
        }
    }

    /**
     * Gathers all trained per-bag models of a partition into a single {@link OcsvmModelData}
     * and writes it out as model rows through {@link OcsvmModelDataConverter}.
     *
     * <p>The bagging number is read from the broadcast variable {@code "bNumber"}.
     */
    public static class Transform extends AbstractRichFunction implements MapPartitionFunction<Tuple2<Double, OcsvmModelData.SvmModelData>, Row> {
        private static final long serialVersionUID = -8875298030671722207L;
        private final String[] featureColNames;
        private int baggingNumber;
        private final HaskernelType.KernelType kernelType;
        private final int degree;
        private double gamma;
        private final double coef0;
        private final String vectorCol;

        /**
         * @param params training parameters from which the column names and kernel settings
         *               stored in the model are taken
         */
        public Transform(Params params) {
            this.vectorCol = (String) params.get(OcsvmModelTrainParams.VECTOR_COL);
            this.featureColNames = (String[]) params.get(OcsvmModelTrainParams.FEATURE_COLS);
            this.kernelType = (HaskernelType.KernelType) params.get(OcsvmModelTrainParams.KERNEL_TYPE);
            this.coef0 = ((Double) params.get(OcsvmModelTrainParams.COEF0)).doubleValue();
            this.degree = ((Integer) params.get(OcsvmModelTrainParams.DEGREE)).intValue();
        }

        /** Reads the bagging number from the first element of the {@code "bNumber"} broadcast set. */
        public void open(Configuration configuration) throws Exception {
            Object broadcastHead = getRuntimeContext().getBroadcastVariable("bNumber").get(0);
            this.baggingNumber = ((Integer) broadcastHead).intValue();
        }

        /**
         * Collects every model of the partition; the gamma of the last element seen wins.
         * Nothing is emitted when the partition holds no model.
         */
        public void mapPartition(Iterable<Tuple2<Double, OcsvmModelData.SvmModelData>> iterable, Collector<Row> collector) throws Exception {
            ArrayList<OcsvmModelData.SvmModelData> collected = new ArrayList<>();
            Iterator<Tuple2<Double, OcsvmModelData.SvmModelData>> entries = iterable.iterator();
            while (entries.hasNext()) {
                Tuple2<Double, OcsvmModelData.SvmModelData> entry = entries.next();
                collected.add(entry.f1);
                this.gamma = entry.f0.doubleValue();
            }
            if (collected.isEmpty()) {
                return;
            }

            OcsvmModelData bundle = new OcsvmModelData();
            bundle.models = collected.toArray(new OcsvmModelData.SvmModelData[0]);
            bundle.featureColNames = this.featureColNames;
            bundle.vectorCol = this.vectorCol;
            bundle.baggingNumber = this.baggingNumber;
            bundle.kernelType = this.kernelType;
            bundle.degree = this.degree;
            bundle.coef0 = this.coef0;
            bundle.gamma = this.gamma;

            OcsvmModelDataConverter converter = new OcsvmModelDataConverter();
            converter.save(bundle, collector);
        }
    }

    /** Creates the operator with a fresh, empty parameter set. */
    public OcsvmModelOutlierTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters handed to the {@link BatchOperator} base class
     */
    public OcsvmModelOutlierTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Wires the training job onto the first input operator.
     *
     * <p>Steps:
     * <ol>
     *   <li>Select the sample data: the vector column if one is configured, otherwise the
     *       feature columns converted to doubles by {@link SelectFeat}.</li>
     *   <li>Derive the number of bags from row count, feature dimension and parallelism
     *       (see {@link #computeBaggingNumber}).</li>
     *   <li>Assign every row to a uniformly random bag, train one model per bag with
     *       {@link TrainSvm}, and merge all models into the output model with
     *       {@link Transform} (parallelism 1).</li>
     * </ol>
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator, with its output set to the model table
     * @throws IllegalArgumentException if neither a vector column nor feature columns are configured
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public OcsvmModelOutlierTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);

        // Empty settings are treated the same as "not set".
        String vectorCol = getVectorCol();
        if ("".equals(vectorCol)) {
            vectorCol = null;
        }
        String[] featureCols = getFeatureCols();
        if (featureCols != null && featureCols.length == 0) {
            featureCols = null;
        }
        if (vectorCol == null && featureCols == null) {
            // Fail fast: without this check the job only breaks later inside the cluster.
            throw new IllegalArgumentException(
                "Either vectorCol or featureCols must be set for one-class SVM training.");
        }
        final double nu = getNu().doubleValue();

        // The vector column takes precedence when both kinds of columns are configured.
        final DataSet<Row> samples;
        if (vectorCol != null) {
            samples = input.select(vectorCol).getDataSet();
        } else {
            String[] inputColNames = input.getColNames();
            int[] featureIndices = new int[featureCols.length];
            for (int i = 0; i < featureCols.length; i++) {
                featureIndices[i] = TableUtil.findColIndexWithAssertAndHint(inputColNames, featureCols[i]);
            }
            samples = input.getDataSet().map(new SelectFeat(featureIndices));
        }

        // (row count, feature dimension, parallelism) per partition -> global -> number of bags.
        DataSet<Integer> baggingNumber = samples
            .mapPartition(new CalculateNum(featureCols, vectorCol))
            .reduce(new ReduceFunction<Tuple3<Integer, Integer, Integer>>() {
                private static final long serialVersionUID = 2307240714136503892L;

                public Tuple3<Integer, Integer, Integer> reduce(Tuple3<Integer, Integer, Integer> left, Tuple3<Integer, Integer, Integer> right) {
                    return Tuple3.of(
                        Integer.valueOf(left.f0.intValue() + right.f0.intValue()),
                        Integer.valueOf(Math.max(left.f1.intValue(), right.f1.intValue())),
                        right.f2);
                }
            })
            .map(new RichMapFunction<Tuple3<Integer, Integer, Integer>, Integer>() {
                public Integer map(Tuple3<Integer, Integer, Integer> stats) {
                    return Integer.valueOf(
                        computeBaggingNumber(stats.f0.intValue(), stats.f1.intValue(), stats.f2.intValue(), nu));
                }
            });

        // Tag each row with a uniformly random bag id in [0, bNumber).
        DataSet<Tuple2<Integer, Row>> baggedSamples = samples
            .mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() {
                private int bNumber;
                private final Random rand = new Random();

                public void open(Configuration configuration) {
                    this.bNumber = ((Integer) getRuntimeContext().getBroadcastVariable("bNumber").get(0)).intValue();
                }

                public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Row>> collector) {
                    for (Row row : iterable) {
                        collector.collect(Tuple2.of(Integer.valueOf(this.rand.nextInt(this.bNumber)), row));
                    }
                }
            })
            .withBroadcastSet(baggingNumber, "bNumber");

        // One model per bag, then a single task merges them into the model rows.
        DataSet<Row> modelRows = baggedSamples
            .groupBy(0)
            .reduceGroup(new TrainSvm(getParams()))
            .mapPartition(new Transform(getParams()))
            .withBroadcastSet(baggingNumber, "bNumber")
            .setParallelism(1);

        setOutput(modelRows, new OcsvmModelDataConverter().getModelSchema());
        return this;
    }

    /**
     * Heuristic for the number of bags: never fewer than the parallelism, otherwise
     * proportional to the amount of training data.
     *
     * <p>The sample/feature product is evaluated in {@code double} so that large inputs
     * cannot overflow {@code int} arithmetic.
     *
     * @param numSamples  total number of training rows
     * @param numFeatures feature dimension
     * @param parallelism lower bound for the result
     * @param nu          the nu parameter, used as a scaling factor
     * @return the number of bags
     */
    private static int computeBaggingNumber(int numSamples, int numFeatures, int parallelism, double nu) {
        final double scaled;
        if (numFeatures < 10) {
            scaled = (double) numSamples * numFeatures * nu / 20000.0d;
        } else if (numFeatures > 10 && numFeatures < 100) {
            scaled = (double) numSamples * numFeatures * nu / 100000.0d;
        } else {
            // NOTE(review): exactly 10 features also lands here (both bounds above are
            // strict); kept as is, but it looks like a boundary slip - confirm.
            scaled = numSamples * nu / 1000.0d;
        }
        return Math.max(parallelism, (int) Math.ceil(scaled));
    }
}
