package com.alibaba.alink.operator.batch.recommendation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp;
import com.alibaba.alink.params.recommendation.NegativeItemSamplingParams;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that augments a two-column (user, item) table with sampled negative items.
 *
 * <p>Every input pair is emitted with label {@code 1}. For each user, up to
 * {@code samplingFactor * (number of input rows of that user)} additional rows are emitted with
 * label {@code 0}; their items are drawn at random from the distinct items of the input and are
 * never items the user already has in the input. The combined result is then linked through a
 * {@link ShuffleBatchOp}.
 *
 * <p>The output schema is the input schema plus a trailing {@code label} column of type LONG.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.DATA)})
@NameCn("推荐负采样")
@NameEn("Negative Item Sampling")
public final class NegativeItemSamplingBatchOp extends BatchOperator<NegativeItemSamplingBatchOp> implements NegativeItemSamplingParams<NegativeItemSamplingBatchOp> {
    private static final long serialVersionUID = 1296665548360617576L;

    /** Maximum number of random draws spent on finding one negative item before giving up on it. */
    private static final int MAX_SAMPLING_ATTEMPTS = 32;

    /** Creates the operator with an empty parameter set. */
    public NegativeItemSamplingBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters, forwarded to the super class
     */
    public NegativeItemSamplingBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the negative sampling job on top of the first input operator.
     *
     * @param batchOperatorArr input operators; only the first one is used and it must have
     *                         exactly two columns (user, item)
     * @return this operator, with its output table set
     * @throws AkIllegalArgumentException if no input is given or the input does not have exactly
     *                                    two columns
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public NegativeItemSamplingBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        AkPreconditions.checkArgument(batchOperatorArr != null && batchOperatorArr.length > 0, (ExceptionWithErrorCode) new AkIllegalArgumentException("input operator of negative item sampling is missing."));
        BatchOperator<?> input = batchOperatorArr[0];
        setMLEnvironmentId(input.getMLEnvironmentId());
        AkPreconditions.checkArgument(input.getColNames().length == 2, (ExceptionWithErrorCode) new AkIllegalArgumentException("num of user item pair column is not equal 2."));
        // The candidate pool for negatives is the set of distinct values of the second (item) column.
        String itemColName = input.getColNames()[1];
        BatchOperator<?> distinctItems = input.select(itemColName).distinct();
        negativeSampling(input, distinctItems);
        setOutputTable(((ShuffleBatchOp) link(new ShuffleBatchOp())).getOutputTable());
        return this;
    }

    /**
     * Converts the rows of the given operator into (user, item) tuples taken from fields 0 and 1.
     *
     * @param batchOperator operator whose rows carry the user in field 0 and the item in field 1
     * @return data set of (user, item) tuples
     */
    private static DataSet<Tuple2<Object, Object>> getUserItemDataSet(BatchOperator<?> batchOperator) {
        return batchOperator.getDataSet().map(new MapFunction<Row, Tuple2<Object, Object>>() {
            private static final long serialVersionUID = -2086770134528760473L;

            @Override
            public Tuple2<Object, Object> map(Row row) {
                return Tuple2.of(row.getField(0), row.getField(1));
            }
        });
    }

    /**
     * Sets this operator's output to the labelled positive rows plus the sampled negative rows.
     *
     * @param batchOperator  the (user, item) input; must have exactly two columns
     * @param batchOperator2 the distinct items used as sampling candidates; must have exactly one column
     */
    private void negativeSampling(BatchOperator<?> batchOperator, BatchOperator<?> batchOperator2) {
        final int samplingFactor = getSamplingFactor().intValue();
        AkPreconditions.checkArgument(batchOperator.getColNames().length == 2, (ExceptionWithErrorCode) new AkIllegalArgumentException("num of data column is not equal 2."));
        AkPreconditions.checkArgument(batchOperator2.getColNames().length == 1, (ExceptionWithErrorCode) new AkIllegalDataException("num of distinctItems column is not equal 1."));

        // Broadcast side input: the raw item values, one per distinct item.
        DataSet<Object> candidateItems = batchOperator2.getDataSet().map(new MapFunction<Row, Object>() {
            private static final long serialVersionUID = -8648184004287735175L;

            @Override
            public Object map(Row row) {
                return row.getField(0);
            }
        });

        DataSet<Row> labelled = getUserItemDataSet(batchOperator).map(new MapFunction<Tuple2<Object, Object>, Tuple3<String, Object, Object>>() {
            private static final long serialVersionUID = -6957327460225823558L;

            @Override
            public Tuple3<String, Object, Object> map(Tuple2<Object, Object> userItem) {
                // Prepend the user's string form; it is the field the rows are grouped on.
                return Tuple3.of(String.valueOf(userItem.f0), userItem.f0, userItem.f1);
            }
        }).groupBy(0).reduceGroup(new RichGroupReduceFunction<Tuple3<String, Object, Object>, Tuple3<Object, Object, Long>>() {
            private static final long serialVersionUID = 306722066512456784L;
            // Items are kept as Object: they come from Row#getField and need not be Long.
            transient List<Object> candidates;
            transient Random random;

            @Override
            public void open(Configuration configuration) {
                // Seeded with the subtask index, so each subtask draws a fixed sequence.
                this.random = new Random(getRuntimeContext().getIndexOfThisSubtask());
                this.candidates = getRuntimeContext().getBroadcastVariable("items");
            }

            @Override
            public void reduce(Iterable<Tuple3<String, Object, Object>> iterable, Collector<Tuple3<Object, Object, Long>> collector) {
                HashSet<Object> seenItems = new HashSet<>();
                Object user = null;
                long positiveCount = 0;
                for (Tuple3<String, Object, Object> record : iterable) {
                    user = record.f1;
                    seenItems.add(record.f2);
                    positiveCount++;
                    collector.collect(Tuple3.of(user, record.f2, 1L));
                }
                if (this.candidates == null || this.candidates.isEmpty()) {
                    // Nothing to draw from; Random#nextInt(0) would throw.
                    return;
                }
                long negativesWanted = positiveCount * samplingFactor;
                for (long n = 0; n < negativesWanted; n++) {
                    // Bounded retry: if no unseen item turns up within the attempt budget
                    // (e.g. the user has seen every item), this negative is skipped instead
                    // of retrying forever.
                    for (int attempt = 0; attempt < MAX_SAMPLING_ATTEMPTS; attempt++) {
                        Object candidate = this.candidates.get(this.random.nextInt(this.candidates.size()));
                        if (!seenItems.contains(candidate)) {
                            collector.collect(Tuple3.of(user, candidate, 0L));
                            break;
                        }
                    }
                }
            }
        }).withBroadcastSet(candidateItems, "items").name("negative_sampling").map(new MapFunction<Tuple3<Object, Object, Long>, Row>() {
            private static final long serialVersionUID = -9124354562578678385L;

            @Override
            public Row map(Tuple3<Object, Object, Long> labelledPair) {
                return Row.of(labelledPair.f0, labelledPair.f1, labelledPair.f2);
            }
        });

        // Output schema: the two input columns followed by a LONG "label" column.
        setOutput(labelled, new TableSchema((String[]) ArrayUtils.add(batchOperator.getColNames(), "label"), (TypeInformation[]) ArrayUtils.add(batchOperator.getColTypes(), Types.LONG)));
    }
}
