package com.alibaba.alink.operator.batch.recommendation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.recommendation.Zipped2KObjectBatchOp;
import com.alibaba.alink.params.recommendation.LeaveTopKObjectOutParams;
import com.alibaba.alink.params.recommendation.Zipped2KObjectParams;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Splits rated records into a left-out part and a remaining part, independently for each group.
 *
 * <p>Rows are grouped by {@code groupCol}. Within each group they are ordered by {@code rateCol}
 * in descending order. The first {@code min(ceil(groupSize * fraction), k)} rows whose rate is not
 * below {@code rateThreshold} are left out; every other row remains.
 *
 * <ul>
 *   <li>Main output: the left-out rows passed through {@link Zipped2KObjectBatchOp}, configured
 *       with a copy of this operator's params plus {@code rateCol} as the info column.</li>
 *   <li>Side output 0: the remaining rows, with the input schema unchanged.</li>
 * </ul>
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "groupCol"),
    @ParamSelectColumnSpec(name = "objectCol"),
    @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})
})
@NameCn("推荐结果TopK采样处理")
@NameEn("Leave TopK Object Out")
public class LeaveTopKObjectOutBatchOp extends BatchOperator<LeaveTopKObjectOutBatchOp>
    implements LeaveTopKObjectOutParams<LeaveTopKObjectOutBatchOp> {
    private static final long serialVersionUID = -2703896174042986392L;

    /**
     * Tags every row of one group with {@code true} (remaining) or {@code false} (left out).
     *
     * <p>Rows are sorted by the rate column, highest first. A row at position {@code i} is left
     * out when {@code i < min(ceil(groupSize * testFraction), testK)} and its rate is at least
     * {@code threshold}. Because the order is descending, once one row falls below the threshold
     * every later row does too.
     */
    public static class Split implements GroupReduceFunction<Row, Tuple2<Boolean, Row>> {
        private static final long serialVersionUID = 5727094306089631645L;
        private final Double testFraction;
        private final Integer testK;
        private final double threshold;
        private final int rateIdx;

        /**
         * @param testFraction fraction of each group (rounded up) that may be left out
         * @param testK        upper bound on the number of rows left out per group
         * @param threshold    minimum rate a row needs in order to be left out
         * @param rateIdx      index of the numeric rate column in the input rows
         */
        public Split(Double testFraction, Integer testK, double threshold, int rateIdx) {
            this.testFraction = testFraction;
            this.testK = testK;
            this.threshold = threshold;
            this.rateIdx = rateIdx;
        }

        @Override
        public void reduce(Iterable<Row> values, Collector<Tuple2<Boolean, Row>> out) {
            // NOTE(review): this buffers the incoming Row references. That assumes object reuse
            // is disabled (Flink's DataSet default). Confirm if the job enables object reuse.
            List<Row> rows = new ArrayList<>();
            values.forEach(rows::add);

            // Highest rate first.
            rows.sort((left, right) -> Double.compare(rateOf(right), rateOf(left)));

            int numLeftOut = Math.min((int) Math.ceil(rows.size() * testFraction), testK);
            for (int i = 0; i < rows.size(); i++) {
                Row row = rows.get(i);
                boolean remains = i >= numLeftOut || rateOf(row) < threshold;
                out.collect(Tuple2.of(remains, row));
            }
        }

        /** Reads the rate column as a double. A null or non-numeric value throws. */
        private double rateOf(Row row) {
            return ((Number) row.getField(rateIdx)).doubleValue();
        }
    }

    /** Emits the row of each tagged pair whose tag equals {@code wanted}, dropping the tag. */
    private static class SelectByTag implements FlatMapFunction<Tuple2<Boolean, Row>, Row> {
        private static final long serialVersionUID = -3476172493628517912L;
        private final boolean wanted;

        SelectByTag(boolean wanted) {
            this.wanted = wanted;
        }

        @Override
        public void flatMap(Tuple2<Boolean, Row> value, Collector<Row> out) {
            if (value.f0 == wanted) {
                out.collect(value.f1);
            }
        }
    }

    public LeaveTopKObjectOutBatchOp() {
        this(new Params());
    }

    public LeaveTopKObjectOutBatchOp(Params params) {
        super(params);
    }

    /**
     * Tags each input row via {@link Split} and zips the left-out rows into the main output. The
     * remaining rows are published as side output 0.
     *
     * @param inputs input operators; only the one returned by {@code checkAndGetFirst} is used
     * @return this operator
     */
    @Override
    public LeaveTopKObjectOutBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        int groupIdx = TableUtil.findColIndexWithAssertAndHint(in.getSchema(), getGroupCol());
        int rateIdx = TableUtil.findColIndexWithAssertAndHint(in.getSchema(), getRateCol());

        DataSet<Tuple2<Boolean, Row>> tagged = in.getDataSet()
            .groupBy(groupIdx)
            .reduceGroup(new Split(getFraction(), getK(), getRateThreshold().doubleValue(), rateIdx));

        // Tag true = remaining row, tag false = left-out row (see Split).
        DataSet<Row> remaining = tagged.flatMap(new SelectByTag(true));
        DataSet<Row> leftOut = tagged.flatMap(new SelectByTag(false));

        // Work on a copy so that configuring the zip step does not write INFO_COLS into this
        // operator's own params.
        Params zipParams = getParams().clone()
            .set(Zipped2KObjectParams.INFO_COLS, new String[] {getRateCol()});
        Zipped2KObjectBatchOp zipped = new Zipped2KObjectBatchOp(zipParams)
            .linkFrom(
                new DataSetWrapperBatchOp(leftOut, in.getColNames(), in.getColTypes())
                    .setMLEnvironmentId(getMLEnvironmentId()));

        setOutput(zipped.getDataSet(), zipped.getSchema());
        setSideOutputTables(new Table[] {
            DataSetConversionUtil.toTable(getMLEnvironmentId(), remaining, in.getSchema())
        });
        return this;
    }
}
