package com.alibaba.alink.operator.batch.huge.line;

import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.common.aps.ApsContext;
import com.alibaba.alink.operator.common.aps.ApsFuncTrain;
import com.alibaba.alink.params.graph.LineParams;
import com.alibaba.alink.params.nlp.HasNegative;
import com.alibaba.alink.params.shared.HasVectorSizeDefaultAs100;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/batch/huge/line/ApsFuncTrainLine.class */
public class ApsFuncTrainLine extends ApsFuncTrain<Number[], float[][]> {
    private static final long serialVersionUID = -2210479949975637732L;
    private final Params params;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/huge/line/ApsFuncTrainLine$TrainRunner.class */
    /**
     * Worker thread that performs a single {@link LinePullAndTrainOperation} training call.
     *
     * <p>The thread only stores the arguments handed to its constructor and forwards them,
     * unchanged, to {@code LinePullAndTrainOperation.train} when it runs. The buffers, mapper
     * and edge list are kept by reference, not copied.
     */
    public static class TrainRunner extends Thread {
        /** Second argument of the {@code LinePullAndTrainOperation} constructor. */
        private final Long[] nsPool;
        /** Edges this worker trains on. */
        List<Number[]> edges;
        int seed;
        /** First argument of the {@code LinePullAndTrainOperation} constructor. */
        int negaTime;
        double sampleRatioPerPartition;
        double leaningRate;
        double minRhoRate;
        int vectorSize;
        /** {@code true} exactly when the order passed to the constructor equals 1. */
        boolean isOrderOne;
        float[] valueBuffer;
        float[] contextBuffer;
        Map<Long, Integer> modelMapper;

        /**
         * Captures everything needed for the training call.
         *
         * @param lArr  stored as {@code nsPool}
         * @param i     stored as {@code seed}
         * @param i2    stored as {@code vectorSize}
         * @param i3    order; only whether it equals 1 is remembered
         * @param i4    stored as {@code negaTime}
         * @param d     stored as {@code sampleRatioPerPartition}
         * @param d2    stored as {@code leaningRate}
         * @param d3    stored as {@code minRhoRate}
         * @param fArr  stored as {@code valueBuffer}
         * @param fArr2 stored as {@code contextBuffer}
         * @param map   stored as {@code modelMapper}
         * @param list  stored as {@code edges}
         */
        TrainRunner(
                Long[] lArr,
                int i,
                int i2,
                int i3,
                int i4,
                double d,
                double d2,
                double d3,
                float[] fArr,
                float[] fArr2,
                Map<Long, Integer> map,
                List<Number[]> list) {
            // Assignments follow the parameter order; they are independent of one another.
            this.nsPool = lArr;
            this.seed = i;
            this.vectorSize = i2;
            this.isOrderOne = (1 == i3);
            this.negaTime = i4;
            this.sampleRatioPerPartition = d;
            this.leaningRate = d2;
            this.minRhoRate = d3;
            this.valueBuffer = fArr;
            this.contextBuffer = fArr2;
            this.modelMapper = map;
            this.edges = list;
        }

        /** Builds the training operation and runs it once with the captured arguments. */
        @Override
        public void run() {
            LinePullAndTrainOperation operation = new LinePullAndTrainOperation(this.negaTime, this.nsPool);
            operation.train(
                    this.seed,
                    this.leaningRate,
                    this.minRhoRate,
                    this.isOrderOne,
                    this.vectorSize,
                    this.sampleRatioPerPartition,
                    this.valueBuffer,
                    this.contextBuffer,
                    this.modelMapper,
                    this.edges);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/huge/line/ApsFuncTrainLine$TrainSubSet.class */
    /**
     * Runs one training pass over a list of edges by splitting it across {@code threadNum}
     * {@link TrainRunner} worker threads.
     *
     * <p>Every worker is given the same value buffer, context buffer and model mapper instances;
     * this class applies no synchronization to them.
     */
    public static class TrainSubSet {
        private final Long[] nsPool;
        int negaTime;
        double sampleRatioPerPartition;
        double leaningRate;
        double minRhoRate;
        int vectorSize;
        int threadNum;

        /**
         * Stores the settings shared by all worker threads.
         *
         * @param nsPool                  forwarded to every {@link TrainRunner}
         * @param negaTime                forwarded to every {@link TrainRunner}
         * @param sampleRatioPerPartition forwarded to every {@link TrainRunner}
         * @param leaningRate             forwarded to every {@link TrainRunner}
         * @param minRhoRate              forwarded to every {@link TrainRunner}
         * @param vectorSize              forwarded to every {@link TrainRunner}
         * @param threadNum               number of worker threads started by {@link #train}
         */
        TrainSubSet(Long[] nsPool, int negaTime, double sampleRatioPerPartition, double leaningRate, double minRhoRate, int vectorSize, int threadNum) {
            this.nsPool = nsPool;
            this.negaTime = negaTime;
            this.sampleRatioPerPartition = sampleRatioPerPartition;
            this.leaningRate = leaningRate;
            this.minRhoRate = minRhoRate;
            this.vectorSize = vectorSize;
            this.threadNum = threadNum;
        }

        /**
         * Trains on {@code edges} using {@code threadNum} threads and waits for all of them.
         *
         * <p>Worker {@code w} is seeded with {@code seed + w} and receives the sub-list of edges
         * assigned to it by {@link DefaultDistributedInfo}.
         *
         * @param seed          base random seed
         * @param edges         edges to train on; workers receive {@code subList} views of it
         * @param order         order value forwarded to each worker
         * @param valueBuffer   buffer shared by all workers
         * @param contextBuffer buffer shared by all workers
         * @param modelMapper   mapper shared by all workers
         * @param superstep     not used by this method
         * @throws InterruptedException  if the calling thread is interrupted while waiting for a worker
         * @throws IllegalStateException if any worker thread terminated with an uncaught
         *                               exception or error; the first recorded one is the cause
         */
        public void train(int seed, List<Number[]> edges, int order, float[] valueBuffer, float[] contextBuffer, Map<Long, Integer> modelMapper, int superstep) throws InterruptedException {
            final int edgeCount = edges.size();

            // A Throwable escaping Thread.run() is invisible to join(). Record the first one so the
            // caller does not mistake a half-finished pass for a successful one.
            final AtomicReference<Throwable> firstFailure = new AtomicReference<Throwable>();
            final Thread.UncaughtExceptionHandler failureRecorder = new Thread.UncaughtExceptionHandler() {
                @Override
                public void uncaughtException(Thread worker, Throwable error) {
                    firstFailure.compareAndSet(null, error);
                }
            };

            Thread[] workers = new Thread[this.threadNum];
            DefaultDistributedInfo partitioner = new DefaultDistributedInfo();
            for (int w = 0; w < this.threadNum; w++) {
                int from = (int) partitioner.startPos(w, this.threadNum, edgeCount);
                int to = ((int) partitioner.localRowCnt(w, this.threadNum, edgeCount)) + from;
                workers[w] = new TrainRunner(this.nsPool, seed + w, this.vectorSize, order, this.negaTime, this.sampleRatioPerPartition, this.leaningRate, this.minRhoRate, valueBuffer, contextBuffer, modelMapper, edges.subList(from, to));
                workers[w].setUncaughtExceptionHandler(failureRecorder);
                workers[w].start();
            }

            for (Thread worker : workers) {
                worker.join();
            }

            // join() has returned for every worker, so a recorded failure is visible here.
            Throwable failure = firstFailure.get();
            if (failure != null) {
                throw new IllegalStateException("LINE training worker thread failed", failure);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates the training function.
     *
     * @param params parameter set kept by reference and read later by
     *               {@code train(List, Map, List)}
     */
    public ApsFuncTrainLine(Params params) {
        super();
        this.params = params;
    }

    /**
     * Runs one multi-threaded training pass on the given rows and rewrites them as deltas.
     *
     * <p>Steps:
     * <ol>
     *   <li>Copy {@code f1[0]} of every row into one flat value buffer. When the configured
     *       order is not 1, also copy {@code f1[1]} into a flat context buffer; when the order
     *       is 1 the context buffer is an empty array.</li>
     *   <li>Hand both buffers, {@code map} and {@code list2} to {@link TrainSubSet#train}.</li>
     *   <li>Overwrite each copied row array in place with
     *       {@code bufferAfterTraining - originalValue}.</li>
     * </ol>
     *
     * @param list  rows to train; their {@code f1} arrays are overwritten with deltas
     * @param map   forwarded unchanged to the worker threads
     * @param list2 edges forwarded unchanged to the worker threads
     * @return the same {@code list} instance
     * @throws AkUnclassifiedErrorException if {@code contextParams} is {@code null}
     * @throws Exception                    anything raised while reading parameters or training
     */
    @Override
    protected List<Tuple2<Long, float[][]>> train(List<Tuple2<Long, float[][]>> list, Map<Long, Integer> map, List<Number[]> list2) throws Exception {
        if (this.contextParams == null) {
            throw new AkUnclassifiedErrorException("Aps server meets RuntimeException when training");
        }

        // Per-job context values.
        final Long[] negBound = this.contextParams.getLongArray("negBound");
        final Long[] seeds = (Long[]) this.contextParams.get(ApsContext.SEEDS);
        final int seed = seeds[getRuntimeContext().getIndexOfThisSubtask()].intValue();

        // User parameters, read in the same order as before.
        final int negative = (Integer) this.params.get(HasNegative.NEGATIVE);
        final int threadNum = this.params.getIntegerOrDefault("threadNum", 8);
        final double sampleRatio = (Double) this.params.get(LineParams.SAMPLE_RATIO_PER_PARTITION);
        final double rho = (Double) this.params.get(LineParams.RHO);
        final int vectorSize = (Integer) this.params.get(HasVectorSizeDefaultAs100.VECTOR_SIZE);
        final int order = ((LineParams.Order) this.params.get(LineParams.ORDER)).getValue();
        final double minRhoRate = (Double) this.params.get(LineParams.MIN_RHO_RATE);

        final boolean orderOne = order == 1;
        final int rowCount = list.size();

        // Flatten the rows into contiguous buffers for the workers.
        final float[] valueBuffer = new float[rowCount * vectorSize];
        final float[] contextBuffer = orderOne ? new float[0] : new float[rowCount * vectorSize];
        for (int row = 0; row < rowCount; row++) {
            final int offset = row * vectorSize;
            System.arraycopy(list.get(row).f1[0], 0, valueBuffer, offset, vectorSize);
            if (!orderOne) {
                System.arraycopy(list.get(row).f1[1], 0, contextBuffer, offset, vectorSize);
            }
        }

        TrainSubSet trainer = new TrainSubSet(negBound, negative, sampleRatio, rho, minRhoRate, vectorSize, threadNum);
        trainer.train(seed, list2, order, valueBuffer, contextBuffer, map, getIterationRuntimeContext().getSuperstepNumber());

        // Replace every row with (trained value - value before training).
        for (int row = 0; row < rowCount; row++) {
            final int offset = row * vectorSize;
            final float[] value = list.get(row).f1[0];
            if (orderOne) {
                for (int k = 0; k < vectorSize; k++) {
                    value[k] = valueBuffer[offset + k] - value[k];
                }
            } else {
                final float[] context = list.get(row).f1[1];
                for (int k = 0; k < vectorSize; k++) {
                    value[k] = valueBuffer[offset + k] - value[k];
                    context[k] = contextBuffer[offset + k] - context[k];
                }
            }
        }
        return list;
    }
}
