package com.alibaba.alink.operator.batch.huge.line;

import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.common.aps.ApsContext;
import com.alibaba.alink.operator.common.aps.ApsFuncIndex4Pull;
import com.alibaba.alink.params.graph.LineParams;
import com.alibaba.alink.params.nlp.HasNegative;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * APS pull-index function for the Line operator.
 *
 * <p>For the rows of one partition it computes the set of indexes to pull. The rows are split
 * into {@code threadNum} contiguous slices (via {@link DefaultDistributedInfo}). Each slice is
 * handled by a {@link NegSampleRunner} thread that calls
 * {@code LinePullAndTrainOperation#getIndexes}. The per-thread results are merged into one set.
 */
public class ApsIndexFunc4PullLine extends ApsFuncIndex4Pull<Number[]> {
    private static final Logger LOG = LoggerFactory.getLogger(ApsIndexFunc4PullLine.class);
    private static final long serialVersionUID = -3540228966146033262L;
    private final Params params;

    /**
     * Worker thread that runs {@code LinePullAndTrainOperation#getIndexes} over one slice of the
     * partition's rows and writes the produced indexes into {@link #output}.
     *
     * <p>Any throwable raised while sampling is captured in {@link #failure} instead of silently
     * terminating the thread, so the coordinating thread can rethrow it after {@link #join()}.
     */
    protected static class NegSampleRunner extends Thread {
        private final Long[] nsPool;
        List<Number[]> input;
        Set<Long> output;
        int seed;
        int negaTime;
        double sampleRatioPerPartition;
        /**
         * Failure raised by {@link #run()}, or {@code null} on success. It is only read after
         * {@link #join()}, which guarantees visibility of the write made by this thread.
         */
        Throwable failure;

        NegSampleRunner(Long[] nsPool, int seed, List<Number[]> input, Set<Long> output, int negaTime,
                        double sampleRatioPerPartition) {
            this.nsPool = nsPool;
            this.input = input;
            this.output = output;
            this.seed = seed;
            this.negaTime = negaTime;
            this.sampleRatioPerPartition = sampleRatioPerPartition;
        }

        @Override
        public void run() {
            try {
                new LinePullAndTrainOperation(this.negaTime, this.nsPool)
                    .getIndexes(this.seed, this.sampleRatioPerPartition, this.input, this.output);
            } catch (Throwable t) {
                // Record rather than let the thread die silently; requestIndex rethrows it.
                this.failure = t;
            }
        }
    }

    /**
     * Creates the function.
     *
     * @param params operator parameters; {@code HasNegative.NEGATIVE} and
     *               {@code LineParams.SAMPLE_RATIO_PER_PARTITION} are required, and
     *               {@code "threadNum"} (default 8) is optional
     */
    public ApsIndexFunc4PullLine(Params params) {
        this.params = params;
    }

    /**
     * Computes the indexes to pull for the given partition rows using parallel worker threads.
     *
     * @param list the rows of this partition
     * @return the union of indexes produced by all worker threads
     * @throws AkUnclassifiedErrorException if {@code contextParams} has not been set
     * @throws Exception the first failure raised by any worker thread (others are attached as
     *                   suppressed), or {@link InterruptedException} if interrupted while joining
     */
    @Override
    protected Set<Long> requestIndex(List<Number[]> list) throws Exception {
        final long startNanos = System.nanoTime();
        final int partitionId = getPatitionId();
        LOG.info("taskId: {}, localId: {}", partitionId, getRuntimeContext().getIndexOfThisSubtask());
        LOG.info("taskId: {}, negInputSize: {}", partitionId, list.size());
        if (null == this.contextParams) {
            throw new AkUnclassifiedErrorException("Aps server meets RuntimeException when pulling index");
        }

        final Long[] negBound = this.contextParams.getLongArray("negBound");
        final int baseSeed =
            ((Long[]) this.contextParams.get(ApsContext.SEEDS))[getRuntimeContext().getIndexOfThisSubtask()]
                .intValue();
        final int negaTime = this.params.get(HasNegative.NEGATIVE);
        // At least one worker: 0 would silently yield no indexes, a negative value would crash.
        final int threadNum = Math.max(1, this.params.getIntegerOrDefault("threadNum", 8));
        final double sampleRatio = this.params.get(LineParams.SAMPLE_RATIO_PER_PARTITION);

        final NegSampleRunner[] runners = new NegSampleRunner[threadNum];
        final DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
        for (int i = 0; i < threadNum; i++) {
            int start = (int) distributedInfo.startPos(i, threadNum, list.size());
            int end = start + (int) distributedInfo.localRowCnt(i, threadNum, list.size());
            LOG.info("taskId: {}, negStart: {}, end: {}", partitionId, start, end);
            // Each worker gets a distinct seed and its own output set, so no sharing of results.
            runners[i] = new NegSampleRunner(negBound, baseSeed + i, list.subList(start, end),
                new HashSet<>(), negaTime, sampleRatio);
            runners[i].start();
        }
        for (NegSampleRunner runner : runners) {
            runner.join();
        }

        Throwable failure = null;
        final Set<Long> result = new HashSet<>();
        for (NegSampleRunner runner : runners) {
            Throwable workerFailure = runner.failure;
            if (workerFailure != null) {
                if (failure == null) {
                    failure = workerFailure;
                } else if (failure != workerFailure) {
                    failure.addSuppressed(workerFailure);
                }
            }
            result.addAll(runner.output);
        }
        if (failure != null) {
            LOG.error("taskId: {}, negative sampling worker failed", partitionId, failure);
            if (failure instanceof Exception) {
                throw (Exception) failure;
            }
            if (failure instanceof Error) {
                throw (Error) failure;
            }
            throw new RuntimeException(failure);
        }

        LOG.info("taskId: {}, negOutputSize: {}", partitionId, result.size());
        LOG.info("taskId: {}, negTime: {}", partitionId, (System.nanoTime() - startNanos) / 1000000.0d);
        return result;
    }
}
