package com.alibaba.alink.operator.batch.huge.word2vec;

import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.common.aps.ApsContext;
import com.alibaba.alink.operator.common.aps.ApsFuncIndex4Pull;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang.ArrayUtils;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Index-pull function for the APS Word2Vec operator.
 *
 * <p>For the rows of one partition, {@link #requestIndex(List)} fans the work out over
 * {@code threadNum} worker threads. Each worker runs {@link Word2Vec#getIndexes} on a contiguous
 * slice of the rows. The function returns the union of the indexes they produce.
 */
public class ApsFuncIndex4PullW2V extends ApsFuncIndex4Pull<int[]> {
    private static final Logger LOG = LoggerFactory.getLogger(ApsFuncIndex4PullW2V.class);
    private static final long serialVersionUID = -341644994350955260L;
    private final Params params;

    /**
     * Worker thread that runs {@link Word2Vec#getIndexes} over one slice of the input rows and adds
     * the resulting indexes to its own output set.
     *
     * <p>Any {@link Throwable} raised by the work is captured and exposed through
     * {@link #getFailure()} rather than being lost in the thread's uncaught-exception handler. The
     * owner must check it after {@link #join()}; {@code join()} makes the worker's writes visible
     * to the joining thread.
     */
    public static class NegSampleRunner extends Thread {
        Params params;
        List<int[]> input;
        Set<Long> output;
        long seed;
        private final int vocSize;
        private final Long[] nsPool;
        private final Object[] groupIdxObjs;
        private final long[] groupIdxStarts;
        private Throwable failure;

        /**
         * @param vocSize        vocabulary size forwarded to {@link Word2Vec}
         * @param nsPool         forwarded to {@link Word2Vec}; the caller in this file passes the
         *                       context's {@code "negBound"} array
         * @param params         operator parameters forwarded to {@link Word2Vec}
         * @param seed           seed forwarded to {@link Word2Vec#getIndexes}
         * @param input          rows this worker processes
         * @param output         set that receives the produced indexes (owned by this worker)
         * @param groupIdxObjs   forwarded to {@link Word2Vec}; may be {@code null}
         * @param groupIdxStarts forwarded to {@link Word2Vec}; {@code null} unless metapath mode is on
         */
        public NegSampleRunner(int vocSize, Long[] nsPool, Params params, long seed, List<int[]> input,
                               Set<Long> output, Object[] groupIdxObjs, long[] groupIdxStarts) {
            this.vocSize = vocSize;
            this.nsPool = nsPool;
            this.params = params;
            this.input = input;
            this.output = output;
            this.seed = seed;
            this.groupIdxObjs = groupIdxObjs;
            this.groupIdxStarts = groupIdxStarts;
        }

        @Override
        public void run() {
            try {
                new Word2Vec(this.params, this.vocSize, this.nsPool, this.groupIdxObjs, this.groupIdxStarts)
                    .getIndexes(this.seed, this.input, this.output);
            } catch (Throwable t) {
                // Record instead of dying silently; the owning thread rethrows it after join().
                this.failure = t;
            }
        }

        /**
         * Returns the throwable raised by this worker, or {@code null} if it completed normally.
         * Only meaningful after the thread has been joined.
         */
        public Throwable getFailure() {
            return this.failure;
        }
    }

    public ApsFuncIndex4PullW2V(Params params) {
        this.params = params;
    }

    /**
     * Computes the union of the indexes produced by {@link Word2Vec#getIndexes} over {@code list}.
     *
     * <p>The rows are split into {@code threadNum} (param, default 8) contiguous slices using
     * {@link DefaultDistributedInfo}. Slice {@code i} is processed by its own
     * {@link NegSampleRunner}, seeded with {@code seeds[partitionId] + i}.
     *
     * @param list input rows of this partition
     * @return the union of all indexes produced by the workers
     * @throws AkUnclassifiedErrorException if {@code contextParams} has not been set
     * @throws IllegalArgumentException     if {@code threadNum} is not positive
     * @throws IllegalStateException        if any worker failed; the first failure is the cause
     *                                      and later ones are suppressed
     * @throws InterruptedException         if interrupted while waiting for the workers
     */
    @Override
    protected Set<Long> requestIndex(List<int[]> list) throws Exception {
        final long startNanos = System.nanoTime();
        final int partitionId = getPatitionId();
        LOG.info("taskId: {}, localId: {}", partitionId, getRuntimeContext().getIndexOfThisSubtask());
        LOG.info("taskId: {}, negInputSize: {}", partitionId, list.size());
        if (null == this.contextParams) {
            throw new AkUnclassifiedErrorException("ApsFunction meets RuntimeException");
        }

        Long[] seeds = (Long[]) this.contextParams.get(ApsContext.SEEDS);
        Long[] negBound = this.contextParams.getLongArray("negBound");
        long[] groupIdxStarts = this.params.getBoolOrDefault("metapathMode", false)
            ? ArrayUtils.toPrimitive(this.contextParams.getLongArray("groupIdxes"))
            : null;
        int vocSize = this.contextParams.getLong("vocSize").intValue();
        long baseSeed = seeds[partitionId];
        int threadNum = this.params.getIntegerOrDefault("threadNum", 8);
        if (threadNum <= 0) {
            // Zero workers would silently produce an empty index set.
            throw new IllegalArgumentException("threadNum must be positive, but was " + threadNum);
        }

        NegSampleRunner[] runners = new NegSampleRunner[threadNum];
        DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
        int rowCount = list.size();
        for (int i = 0; i < threadNum; i++) {
            int start = (int) distributedInfo.startPos(i, threadNum, rowCount);
            int end = start + (int) distributedInfo.localRowCnt(i, threadNum, rowCount);
            LOG.info("taskId: {}, negStart: {}, end: {}", partitionId, start, end);
            // Each worker gets its own seed, slice and output set; results are merged after join().
            runners[i] = new NegSampleRunner(vocSize, negBound, this.params, baseSeed + i,
                list.subList(start, end), new HashSet<>(), null, groupIdxStarts);
            runners[i].start();
        }
        for (NegSampleRunner runner : runners) {
            runner.join();
        }
        rethrowWorkerFailures(runners, partitionId);

        Set<Long> merged = new HashSet<>();
        for (NegSampleRunner runner : runners) {
            merged.addAll(runner.output);
        }
        LOG.info("taskId: {}, negOutputSize: {}", partitionId, merged.size());
        LOG.info("taskId: {}, negTime: {}", partitionId, (System.nanoTime() - startNanos) / 1_000_000.0d);
        return merged;
    }

    /**
     * Throws if any (already joined) worker failed. The first failure becomes the cause; any later
     * failures are attached as suppressed exceptions so none is lost.
     */
    private static void rethrowWorkerFailures(NegSampleRunner[] runners, int partitionId) {
        IllegalStateException aggregated = null;
        for (int i = 0; i < runners.length; i++) {
            Throwable failure = runners[i].getFailure();
            if (failure == null) {
                continue;
            }
            if (aggregated == null) {
                aggregated = new IllegalStateException(
                    "Negative sampling thread " + i + " of task " + partitionId + " failed.", failure);
            } else {
                aggregated.addSuppressed(failure);
            }
        }
        if (aggregated != null) {
            throw aggregated;
        }
    }
}
