package com.alibaba.alink.operator.batch.huge.word2vec;

import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.common.aps.ApsFuncTrain;
import com.alibaba.alink.params.shared.HasVectorSizeDefaultAs100;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.ArrayUtils;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/batch/huge/word2vec/ApsFuncTrainW2V.class */
public class ApsFuncTrainW2V extends ApsFuncTrain<int[], float[]> {
    private static final Logger LOG = LoggerFactory.getLogger(ApsFuncTrainW2V.class);
    private static final long serialVersionUID = -6458331525690591343L;
    private final Params params;

    /**
     * Splits one partition's training data into slices and trains each slice on its own
     * {@link TrainSubSetRunner} thread.
     *
     * <p>All runners receive the same {@code buffer} array and therefore write to it concurrently;
     * no synchronization is performed here.
     */
    public static class TrainSubSet {
        private final int vocSize;
        private final Long[] nsPool;
        private final Object[] groupIdxObjs;
        private final long[] groupIdxStarts;
        private final Params params;
        private final int taskId;

        /**
         * @param vocSize        vocabulary size, forwarded unchanged to {@code Word2Vec}
         * @param nsPool         forwarded unchanged to {@code Word2Vec}; the enclosing operator passes
         *                       the {@code negBound} context parameter here
         * @param params         operator parameters; {@code threadNum} (default 8) is read here
         * @param taskId         partition id, used only for logging and error messages
         * @param groupIdxObjs   forwarded unchanged to {@code Word2Vec} (may be {@code null})
         * @param groupIdxStarts forwarded unchanged to {@code Word2Vec} (may be {@code null})
         */
        public TrainSubSet(int vocSize, Long[] nsPool, Params params, int taskId,
                           Object[] groupIdxObjs, long[] groupIdxStarts) {
            this.vocSize = vocSize;
            this.nsPool = nsPool;
            this.params = params;
            this.taskId = taskId;
            this.groupIdxObjs = groupIdxObjs;
            this.groupIdxStarts = groupIdxStarts;
        }

        /**
         * Trains {@code trainData} with {@code threadNum} worker threads and blocks until all finish.
         *
         * <p>Worker {@code i} trains the sub-list
         * {@code [startPos(i), startPos(i) + localRowCnt(i))}, as computed by
         * {@link DefaultDistributedInfo}. It is seeded with {@code seed + i}.
         *
         * @param seed            base random seed; each worker adds its own index
         * @param trainData       the partition's training rows
         * @param buffer          model buffer shared by (and updated in place by) every worker
         * @param featureId2Local passed unchanged to every worker
         * @throws InterruptedException     if interrupted while waiting for the workers
         * @throws IllegalArgumentException if the {@code threadNum} parameter is not positive
         * @throws IllegalStateException    if any worker thread terminated with an exception; the
         *                                  first failure is the cause, later ones are suppressed
         */
        public void train(long seed, List<int[]> trainData, float[] buffer,
                          Map<Long, Integer> featureId2Local) throws InterruptedException {
            int threadNum = this.params.getIntegerOrDefault("threadNum", 8);
            if (threadNum < 1) {
                throw new IllegalArgumentException("threadNum must be positive, but got " + threadNum);
            }
            int dataSize = trainData.size();
            DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
            Thread[] workers = new Thread[threadNum];
            // Slot i is written only by worker i's uncaught-exception handler, which runs on that
            // worker before it terminates. Thread.join() makes the write visible to this thread.
            Throwable[] failures = new Throwable[threadNum];

            for (int i = 0; i < threadNum; i++) {
                int start = (int) distributedInfo.startPos(i, threadNum, dataSize);
                int end = start + (int) distributedInfo.localRowCnt(i, threadNum, dataSize);
                LOG.info("taskId: {}, trainStart: {}, end: {}", this.taskId, start, end);
                Thread worker = new TrainSubSetRunner(this.vocSize, this.nsPool, this.params, seed + i,
                    trainData.subList(start, end), buffer, featureId2Local,
                    this.groupIdxObjs, this.groupIdxStarts);
                final int slot = i;
                // Without this handler a worker failure would only be printed to stderr, and the
                // caller would go on to publish a partially trained model.
                worker.setUncaughtExceptionHandler((thread, error) -> failures[slot] = error);
                workers[i] = worker;
                worker.start();
            }

            for (Thread worker : workers) {
                worker.join();
            }

            IllegalStateException failure = null;
            for (int i = 0; i < threadNum; i++) {
                if (failures[i] == null) {
                    continue;
                }
                if (failure == null) {
                    failure = new IllegalStateException(
                        "Word2Vec training thread " + i + " of task " + this.taskId + " failed", failures[i]);
                } else {
                    failure.addSuppressed(failures[i]);
                }
            }
            if (failure != null) {
                throw failure;
            }
        }
    }

    /**
     * Worker thread that trains one slice of a partition's data. It delegates to
     * {@code Word2Vec#train3} (defined elsewhere in this package).
     */
    public static class TrainSubSetRunner extends Thread {
        Params params;
        List<int[]> input;
        long seed;
        Map<Long, Integer> mapFeatureId2Local;
        float[] buffer;
        private final int vocSize;
        private final Long[] nsPool;
        private final Object[] groupIdxObjs;
        private final long[] groupIdxStarts;

        /**
         * @param vocSize            vocabulary size, forwarded to {@code Word2Vec}
         * @param nsPool             forwarded to {@code Word2Vec}
         * @param params             operator parameters, forwarded to {@code Word2Vec}
         * @param seed               random seed for this worker
         * @param input              the slice of training rows this worker processes
         * @param buffer             shared model buffer passed to {@code train3}
         * @param mapFeatureId2Local passed to {@code train3}
         * @param groupIdxObjs       forwarded to {@code Word2Vec} (may be {@code null})
         * @param groupIdxStarts     forwarded to {@code Word2Vec} (may be {@code null})
         */
        public TrainSubSetRunner(int vocSize, Long[] nsPool, Params params, long seed, List<int[]> input,
                                 float[] buffer, Map<Long, Integer> mapFeatureId2Local,
                                 Object[] groupIdxObjs, long[] groupIdxStarts) {
            this.vocSize = vocSize;
            this.nsPool = nsPool;
            this.params = params;
            this.input = input;
            this.buffer = buffer;
            this.mapFeatureId2Local = mapFeatureId2Local;
            this.groupIdxObjs = groupIdxObjs;
            this.groupIdxStarts = groupIdxStarts;
            this.seed = seed;
        }

        /** Builds a fresh {@code Word2Vec} and runs {@code train3} over this worker's slice. */
        @Override
        public void run() {
            new Word2Vec(this.params, this.vocSize, this.nsPool, this.groupIdxObjs, this.groupIdxStarts)
                .train3(this.seed, this.input, this.buffer, this.mapFeatureId2Local);
        }
    }

    /**
     * @param params operator parameters; {@code threadNum}, {@code metapathMode} and
     *               {@link HasVectorSizeDefaultAs100#VECTOR_SIZE} are read from them
     */
    public ApsFuncTrainW2V(Params params) {
        this.params = params;
    }

    /**
     * Trains the pulled model rows on this partition's data and returns them updated in place.
     *
     * <p>Each row's vector (first {@code vectorSize} floats of {@code f1}) is copied, in list order,
     * into one flat buffer. The buffer is trained by {@link TrainSubSet}, and then copied back into
     * the same rows. NOTE(review): {@code featureId2Local} presumably maps a row id to its position
     * in {@code model} (and therefore in the buffer); confirm against {@code ApsFuncTrain}.
     *
     * <p>Context parameters read: {@code seeds} (indexed by partition id), {@code negBound},
     * {@code vocSize} and, in metapath mode, {@code groupIdxes}.
     *
     * @param model           pulled model rows; their {@code f1} arrays are overwritten
     * @param featureId2Local forwarded to the training workers
     * @param trainData       this partition's training rows
     * @return {@code model}, with updated vectors
     * @throws AkUnclassifiedErrorException if context parameters are unavailable
     * @throws IllegalStateException        if a training worker thread fails
     */
    @Override
    protected List<Tuple2<Long, float[]>> train(List<Tuple2<Long, float[]>> model,
                                                Map<Long, Integer> featureId2Local,
                                                List<int[]> trainData) throws Exception {
        if (null == this.contextParams) {
            throw new AkUnclassifiedErrorException("ApsFunction meets RuntimeException");
        }
        long startNanos = System.nanoTime();
        int taskId = getPatitionId();
        LOG.info("taskId: {}, localId: {}", taskId, getRuntimeContext().getIndexOfThisSubtask());

        Long[] seeds = this.contextParams.getLongArray("seeds");
        Long[] negBound = this.contextParams.getLongArray("negBound");
        long[] groupIdxStarts = this.params.getBoolOrDefault("metapathMode", false)
            ? ArrayUtils.toPrimitive(this.contextParams.getLongArray("groupIdxes"))
            : null;
        int vocSize = this.contextParams.getLong("vocSize").intValue();
        long seed = seeds[taskId];
        LOG.info("taskId: {}, trainDataSize: {}", taskId, trainData.size());

        int vectorSize = (Integer) this.params.get(HasVectorSizeDefaultAs100.VECTOR_SIZE);
        float[] buffer = new float[model.size() * vectorSize];
        int rowCount = 0;
        // Rows with id < vocSize are reported as "input" rows, the rest as "output" rows
        // (presumably input vs. context embeddings — only used for logging here).
        int inputRowCount = 0;
        for (Tuple2<Long, float[]> row : model) {
            if (row.f0 < vocSize) {
                inputRowCount++;
            }
            System.arraycopy(row.f1, 0, buffer, rowCount * vectorSize, vectorSize);
            rowCount++;
        }
        LOG.info("taskId: {}, trainInputSize: {}", taskId, inputRowCount);
        LOG.info("taskId: {}, trainOutputSize: {}", taskId, model.size() - inputRowCount);

        new TrainSubSet(vocSize, negBound, this.params, taskId, null, groupIdxStarts)
            .train(seed, trainData, buffer, featureId2Local);

        // Copy back in the same iteration order used for flattening. Iterating avoids
        // O(n^2) List.get(i) on non-random-access lists.
        int offset = 0;
        for (Tuple2<Long, float[]> row : model) {
            System.arraycopy(buffer, offset, row.f1, 0, vectorSize);
            offset += vectorSize;
        }
        LOG.info("taskId: {}, trainTime: {}", taskId, (System.nanoTime() - startNanos) / 1000000.0d);
        return model;
    }
}
