package com.alibaba.alink.operator.common.optim;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.recommendation.FmTrainParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/FmOptimizer.class */
/**
 * Distributed optimizer for factorization machines (FM), built on {@link IterativeComQueue}.
 *
 * <p>Each iteration runs the following steps:
 * <ol>
 *   <li>Every worker samples its local partition and applies AdaGrad-style updates to its copy of
 *       the model ({@link UpdateLocalModel}).</li>
 *   <li>The sample-weighted local models are all-reduced and averaged ({@link UpdateGlobalModel}).</li>
 *   <li>Loss and evaluation statistics are computed and all-reduced ({@link CalcLossAndEvaluation}).</li>
 *   <li>Node 0 decides whether to stop ({@link FmIterTermination}).</li>
 * </ol>
 * When training stops, node 0 serializes the model into rows ({@link OutputFmModel}). Those rows
 * are parsed back into a model plus convergence information ({@link ParseRowModel}).
 *
 * <p>Training samples are {@code Tuple3<f0, f1, f2>}:
 * <ul>
 *   <li>{@code f0} is accumulated as a per-feature sample weight when models are averaged.</li>
 *   <li>{@code f1} is passed to the loss as the label.</li>
 *   <li>{@code f2} is the feature vector.</li>
 * </ul>
 *
 * <p>The {@code dim} array is {@code {intercept ? 1 : 0, linear ? 1 : 0, numFactors}}. Each row of
 * {@code FmDataFormat.factors} holds {@code dim[2]} latent factors, followed by the linear weight
 * when {@code dim[1] > 0}.
 */
public class FmOptimizer {
    private final Params params;
    private final DataSet<Tuple3<Double, Double, Vector>> trainData;
    /** Regularization for the intercept, linear weights and latent factors, in that order. */
    private final double[] lambda;
    /** Number of factor rows packed into one serialized model row. */
    private static final int BLOCK_SIZE = 1000;
    /** Initial model; must be provided through {@link #setWithInitFactors} before {@link #optimize()}. */
    protected DataSet<BaseFmTrainBatchOp.FmDataFormat> fmModel = null;
    /** {@code {intercept ? 1 : 0, linear ? 1 : 0, numFactors}}. */
    private final int[] dim = new int[3];

    /**
     * Computes loss and evaluation statistics of the current global model on this worker's
     * partition. The results go into {@code OptimVariable.lossAucAllReduce} for all-reduce.
     *
     * <p>Layout of the 4-element buffer:
     * <ul>
     *   <li>{@code [0]}: summed loss.</li>
     *   <li>{@code [1]}: number of local samples.</li>
     *   <li>{@code [2]}: local AUC for classification, or summed absolute error for regression.</li>
     *   <li>{@code [3]}: number of correctly classified samples for classification, or summed squared
     *       error for regression.</li>
     * </ul>
     */
    public static class CalcLossAndEvaluation extends ComputeFunction {
        private static final long serialVersionUID = 1276524768860519162L;
        private final int[] dim;
        /** Reusable buffer of raw model scores, one per local sample. */
        private double[] y;
        private final BaseFmTrainBatchOp.LossFunction lossFunc;
        private final BaseFmTrainBatchOp.Task task;

        public CalcLossAndEvaluation(int[] dim, BaseFmTrainBatchOp.Task task) {
            this.dim = dim;
            this.task = task;
            this.lossFunc = createLossFunction(task);
        }

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            double[] stats = (double[]) comContext.getObj(OptimVariable.lossAucAllReduce);
            if (stats == null) {
                stats = new double[4];
                comContext.putObj(OptimVariable.lossAucAllReduce, stats);
            }
            List<Tuple3<Double, Double, Vector>> samples =
                (List<Tuple3<Double, Double, Vector>>) comContext.getObj(OptimVariable.fmTrainData);
            BaseFmTrainBatchOp.FmDataFormat model =
                ((List<BaseFmTrainBatchOp.FmDataFormat>) comContext.getObj(OptimVariable.fmModel)).get(0);

            // Reallocate defensively if the partition size ever differs from the cached buffer.
            if (y == null || y.length != samples.size()) {
                y = new double[samples.size()];
            }
            double lossSum = 0.0;
            for (int i = 0; i < y.length; i++) {
                y[i] = FmOptimizer.calcY(samples.get(i).f2, model, dim).f0;
                lossSum += lossFunc.l(samples.get(i).f1, y[i]);
            }

            if (task.equals(BaseFmTrainBatchOp.Task.REGRESSION)) {
                double absErrorSum = 0.0;
                double squaredErrorSum = 0.0;
                for (int i = 0; i < y.length; i++) {
                    double error = y[i] - samples.get(i).f1;
                    absErrorSum += Math.abs(error);
                    squaredErrorSum += error * error;
                }
                stats[2] = absErrorSum;
                stats[3] = squaredErrorSum;
            } else {
                stats[2] = localAuc(samples);
                stats[3] = countCorrect(samples);
            }
            stats[0] = lossSum;
            stats[1] = y.length;
        }

        /**
         * Counts samples whose score sign agrees with the label. A label above 0.5 is positive.
         * A score of exactly 0 is never counted as correct.
         */
        private double countCorrect(List<Tuple3<Double, Double, Vector>> samples) {
            double correct = 0.0;
            for (int i = 0; i < y.length; i++) {
                double label = samples.get(i).f1;
                if ((y[i] > 0.0 && label > 0.5) || (y[i] < 0.0 && label < 0.5)) {
                    correct += 1.0;
                }
            }
            return correct;
        }

        /**
         * Rank-based (Mann-Whitney) AUC of the scores in {@link #y} on this partition.
         *
         * <p>Returns 0 when the partition lacks either positives or negatives. Tied scores get
         * consecutive ranks in sort order; they are not averaged. Counts are {@code long} and the
         * denominator is computed in {@code double}, so large partitions cannot overflow.
         */
        private double localAuc(List<Tuple3<Double, Double, Vector>> samples) {
            final double[] scores = y;
            Integer[] order = new Integer[scores.length];
            for (int i = 0; i < order.length; i++) {
                order[i] = i;
            }
            Arrays.sort(order, Comparator.comparingDouble((Integer i) -> scores[i]));

            long numPositive = 0;
            long numNegative = 0;
            double positiveRankSum = 0.0;
            for (int rank = 1; rank <= order.length; rank++) {
                if (samples.get(order[rank - 1]).f1 > 0.5) {
                    numPositive++;
                    positiveRankSum += rank;
                } else {
                    numNegative++;
                }
            }
            if (numPositive == 0 || numNegative == 0) {
                return 0.0;
            }
            return (positiveRankSum - 0.5 * numPositive * (numPositive + 1.0))
                / ((double) numPositive * numNegative);
        }
    }

    /**
     * Convergence check evaluated on node 0 after every iteration.
     *
     * <p>Per-step statistics are recorded in {@code OptimVariable.convergenceInfo} as triples:
     * {@code [loss, auc, accuracy]} for binary classification, or {@code [loss, mae, mse]}
     * otherwise. Training stops when the step budget is exhausted or when the relative loss change
     * falls below {@code epsilon}.
     */
    public static class FmIterTermination extends CompareCriterionFunction {
        private static final long serialVersionUID = 2410437704683855923L;
        /**
         * Loss of the previous step. It is 0 before the first step, so the relative-change test
         * yields Infinity or NaN and cannot succeed on step one.
         */
        private double oldLoss;
        private final double epsilon;
        private final BaseFmTrainBatchOp.Task task;
        private final int maxIter;
        private final int batchSize;
        private long oldTime = System.currentTimeMillis();

        public FmIterTermination(Params params) {
            this.maxIter = (Integer) params.get(FmTrainParams.NUM_EPOCHS);
            this.epsilon = (Double) params.get(FmTrainParams.EPSILON);
            this.task = (BaseFmTrainBatchOp.Task) params.get(ModelParamName.TASK);
            this.batchSize = (Integer) params.get(FmTrainParams.MINIBATCH_SIZE);
        }

        @Override
        public boolean calc(ComContext comContext) {
            int numLocalSamples = ((List<?>) comContext.getObj(OptimVariable.fmTrainData)).size();
            // Full-batch training takes one step per epoch; mini-batch training takes
            // (n / batchSize + 1) steps per epoch.
            int maxSteps = (batchSize == -1 || batchSize > numLocalSamples)
                ? maxIter
                : (numLocalSamples / batchSize + 1) * maxIter;
            double[] convergenceInfo = (double[]) comContext.getObj(OptimVariable.convergenceInfo);
            if (convergenceInfo == null) {
                convergenceInfo = new double[maxSteps * 3];
                comContext.putObj(OptimVariable.convergenceInfo, convergenceInfo);
            }
            // NOTE(review): assumes getStepNo() is 1-based; `step` is the zero-based record index.
            int step = comContext.getStepNo() - 1;
            double[] stats = (double[]) comContext.getObj(OptimVariable.lossAucAllReduce);
            double loss = stats[0] / stats[1];
            convergenceInfo[3 * step] = loss;
            convergenceInfo[3 * step + 2] = stats[3] / stats[1];

            if (task.equals(BaseFmTrainBatchOp.Task.BINARY_CLASSIFICATION)) {
                // Per-worker AUCs are divided by the worker count to average them
                // (assumes AllReduce sums element-wise).
                double auc = stats[2] / comContext.getNumTask();
                convergenceInfo[3 * step + 1] = auc;
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("step : " + step + " loss : " + loss + "  auc : " + auc
                        + " accuracy : " + (stats[3] / stats[1])
                        + " time : " + (System.currentTimeMillis() - oldTime));
                    oldTime = System.currentTimeMillis();
                }
            } else {
                double mae = stats[2] / stats[1];
                convergenceInfo[3 * step + 1] = mae;
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("step : " + step + " loss : " + loss + "  mae : " + mae
                        + " mse : " + (stats[3] / stats[1])
                        + " time : " + (System.currentTimeMillis() - oldTime));
                    oldTime = System.currentTimeMillis();
                }
            }

            if (comContext.getStepNo() >= maxSteps || Math.abs(oldLoss - loss) / oldLoss < epsilon) {
                return true;
            }
            oldLoss = loss;
            return false;
        }
    }

    /**
     * Serializes the final model on task 0 into rows. Other tasks return {@code null}.
     *
     * <p>Row formats:
     * <ul>
     *   <li>{@code (blockId >= 0, numFeatures, json(double[][] factorBlock))}: holds up to
     *       {@link #BLOCK_SIZE} factor rows starting at {@code blockId * BLOCK_SIZE}.</li>
     *   <li>{@code (-2, json(model without factors))}.</li>
     *   <li>{@code (-1, json(convergenceInfo))}.</li>
     * </ul>
     */
    public static class OutputFmModel extends CompleteResultFunction {
        private static final long serialVersionUID = 727259322769437038L;

        @Override
        @SuppressWarnings("unchecked")
        public List<Row> calc(ComContext comContext) {
            if (comContext.getTaskId() != 0) {
                return null;
            }
            double[] convergenceInfo = (double[]) comContext.getObj(OptimVariable.convergenceInfo);
            BaseFmTrainBatchOp.FmDataFormat model =
                ((List<BaseFmTrainBatchOp.FmDataFormat>) comContext.getObj(OptimVariable.fmModel)).get(0);
            List<Row> rows = new ArrayList<>();

            double[][] factorTable = model.factors;
            int numRows = factorTable.length;
            // Copy row references (double[][] -> double[][]). The former double[] buffer made
            // System.arraycopy throw ArrayStoreException. An empty table now emits no factor rows.
            for (int start = 0; start < numRows; start += BLOCK_SIZE) {
                int end = Math.min(start + BLOCK_SIZE, numRows);
                double[][] block = Arrays.copyOfRange(factorTable, start, end);
                rows.add(Row.of(start / BLOCK_SIZE, numRows, JsonConverter.toJson(block)));
            }

            // Factors were already emitted in blocks; drop them so the header row stays small.
            model.factors = null;
            rows.add(Row.of(-2, JsonConverter.toJson(model)));
            rows.add(Row.of(-1, JsonConverter.toJson(convergenceInfo)));
            return rows;
        }
    }

    /**
     * Rebuilds the model and convergence info from the rows written by {@link OutputFmModel}.
     * Only subtask 0 does this work and emits a single tuple.
     */
    public static class ParseRowModel
        extends RichMapPartitionFunction<Row, Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]>> {
        private static final long serialVersionUID = -2078134573230730223L;

        @Override
        public void mapPartition(Iterable<Row> rows,
                                 Collector<Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]>> collector)
            throws Exception {
            if (getRuntimeContext().getIndexOfThisSubtask() != 0) {
                return;
            }
            BaseFmTrainBatchOp.FmDataFormat model = new BaseFmTrainBatchOp.FmDataFormat();
            double[] convergenceInfo = new double[0];
            double[][] factors = null;
            for (Row row : rows) {
                int id = (Integer) row.getField(0);
                if (id >= 0) {
                    if (factors == null) {
                        // One slot per feature; rows are filled in from the factor blocks.
                        factors = new double[(Integer) row.getField(1)][];
                    }
                    double[][] block =
                        (double[][]) JsonConverter.fromJson((String) row.getField(2), double[][].class);
                    System.arraycopy(block, 0, factors, BLOCK_SIZE * id, block.length);
                } else if (id == -1) {
                    convergenceInfo = (double[]) JsonConverter.fromJson((String) row.getField(1), double[].class);
                } else if (id == -2) {
                    model = (BaseFmTrainBatchOp.FmDataFormat) JsonConverter.fromJson(
                        (String) row.getField(1), BaseFmTrainBatchOp.FmDataFormat.class);
                }
            }
            model.factors = factors;
            collector.collect(Tuple2.of(model, convergenceInfo));
        }
    }

    /**
     * Builds the new global model from the all-reduced {@code OptimVariable.factorAllReduce}
     * buffer. The buffer layout is described on {@link UpdateLocalModel}.
     *
     * <p>Assuming AllReduce sums element-wise:
     * <ul>
     *   <li>Each feature's factors, linear weight and AdaGrad accumulators become the
     *       sample-weighted average over workers.</li>
     *   <li>The intercept and its accumulator become the plain average over workers.</li>
     *   <li>Features with zero total weight keep their current values.</li>
     * </ul>
     */
    public static class UpdateGlobalModel extends ComputeFunction {
        private static final long serialVersionUID = 4584059654350995646L;
        private final int[] dim;

        public UpdateGlobalModel(int[] dim) {
            this.dim = dim;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            double[] buffer = (double[]) comContext.getObj(OptimVariable.factorAllReduce);
            BaseFmTrainBatchOp.FmDataFormat sigma =
                (BaseFmTrainBatchOp.FmDataFormat) comContext.getObj(OptimVariable.sigmaGii);
            BaseFmTrainBatchOp.FmDataFormat model =
                ((List<BaseFmTrainBatchOp.FmDataFormat>) comContext.getObj(OptimVariable.fmModel)).get(0);

            int numFactors = dim[2];
            int rowWidth = dim[2] + dim[1];
            int vectorSize = (buffer.length - 2 * dim[0]) / (2 * rowWidth + 1);
            int linearOffset = 2 * vectorSize * numFactors;
            int linearSigmaOffset = vectorSize * (2 * numFactors + 1);
            int weightOffset = 2 * vectorSize * rowWidth;

            for (int i = 0; i < vectorSize; i++) {
                double weight = buffer[weightOffset + i];
                if (weight > 0.0) {
                    for (int f = 0; f < numFactors; f++) {
                        model.factors[i][f] = buffer[i * numFactors + f] / weight;
                        sigma.factors[i][f] = buffer[(vectorSize + i) * numFactors + f] / weight;
                    }
                    if (dim[1] > 0) {
                        model.factors[i][numFactors] = buffer[linearOffset + i] / weight;
                        sigma.factors[i][numFactors] = buffer[linearSigmaOffset + i] / weight;
                    }
                }
            }
            if (dim[0] > 0) {
                model.bias = buffer[buffer.length - 2] / comContext.getNumTask();
                sigma.bias = buffer[buffer.length - 1] / comContext.getNumTask();
            }
        }
    }

    /**
     * Runs one round of local AdaGrad-style updates on this worker's copy of the model. The result
     * is packed into {@code OptimVariable.factorAllReduce} for averaging.
     *
     * <p>Buffer layout, with {@code V = vectorSize}, {@code k = dim[2]} and {@code d1 = dim[1]}:
     * <ul>
     *   <li>{@code [0, V*k)}: factors times the local feature weight.</li>
     *   <li>{@code [V*k, 2*V*k)}: factor accumulators times the local feature weight.</li>
     *   <li>{@code [2*V*k, 2*V*k + V)}: linear weights times the feature weight (if {@code d1 > 0}).</li>
     *   <li>{@code [V*(2k+1), V*(2k+2))}: linear accumulators times the feature weight
     *       (if {@code d1 > 0}).</li>
     *   <li>{@code [2*V*(k+d1), 2*V*(k+d1) + V)}: local feature weights.</li>
     *   <li>Last two entries: intercept and its accumulator (if {@code dim[0] > 0}).</li>
     * </ul>
     */
    public static class UpdateLocalModel extends ComputeFunction {
        private static final long serialVersionUID = 5331512619834061299L;
        private final int[] dim;
        private final double[] lambda;
        private final double learnRate;
        private int vectorSize;
        /**
         * Mini-batch size. Each call draws {@code 2 * batchSize} samples with replacement.
         * A value of -1 means the whole local partition and is resolved on the first call.
         */
        private int batchSize;
        private final BaseFmTrainBatchOp.LossFunction lossFunc;
        /** Fixed seed, so the sampled index sequence is deterministic. */
        private final Random rand = new Random(2020);

        public UpdateLocalModel(int[] dim, double[] lambda, Params params) {
            this.lambda = lambda;
            this.dim = dim;
            BaseFmTrainBatchOp.Task task = (BaseFmTrainBatchOp.Task) params.get(ModelParamName.TASK);
            this.learnRate = (Double) params.get(FmTrainParams.LEARN_RATE);
            this.batchSize = (Integer) params.get(FmTrainParams.MINIBATCH_SIZE);
            this.lossFunc = createLossFunction(task);
        }

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            List<Tuple3<Double, Double, Vector>> samples =
                (List<Tuple3<Double, Double, Vector>>) comContext.getObj(OptimVariable.fmTrainData);
            if (batchSize == -1) {
                batchSize = samples.size();
            }
            BaseFmTrainBatchOp.FmDataFormat sigma =
                (BaseFmTrainBatchOp.FmDataFormat) comContext.getObj(OptimVariable.sigmaGii);
            BaseFmTrainBatchOp.FmDataFormat model =
                ((List<BaseFmTrainBatchOp.FmDataFormat>) comContext.getObj(OptimVariable.fmModel)).get(0);

            // Refresh on every call so the size stays valid even if this instance is recreated.
            vectorSize = model.factors.length;
            double[] featureWeights = (double[]) comContext.getObj(OptimVariable.weights);
            if (featureWeights == null) {
                featureWeights = new double[vectorSize];
                comContext.putObj(OptimVariable.weights, featureWeights);
            } else {
                Arrays.fill(featureWeights, 0.0);
            }
            if (sigma == null) {
                // Squared-gradient accumulators, built with the same vectorSize and dim as the model.
                sigma = new BaseFmTrainBatchOp.FmDataFormat(vectorSize, dim, 0.0);
                comContext.putObj(OptimVariable.sigmaGii, sigma);
            }

            updateFactors(samples, model, learnRate, sigma, featureWeights);
            packForAllReduce(comContext, model, sigma, featureWeights);
        }

        /** Writes the weighted local model and accumulators into the all-reduce buffer. */
        private void packForAllReduce(ComContext comContext,
                                      BaseFmTrainBatchOp.FmDataFormat model,
                                      BaseFmTrainBatchOp.FmDataFormat sigma,
                                      double[] featureWeights) {
            int numFactors = dim[2];
            int rowWidth = dim[2] + dim[1];
            double[] buffer = (double[]) comContext.getObj(OptimVariable.factorAllReduce);
            if (buffer == null) {
                buffer = new double[vectorSize * rowWidth * 2 + vectorSize + 2 * dim[0]];
                comContext.putObj(OptimVariable.factorAllReduce, buffer);
            } else {
                Arrays.fill(buffer, 0.0);
            }

            int linearOffset = 2 * vectorSize * numFactors;
            int linearSigmaOffset = vectorSize * (2 * numFactors + dim[1]);
            int weightOffset = 2 * vectorSize * rowWidth;
            for (int i = 0; i < vectorSize; i++) {
                double w = featureWeights[i];
                double[] modelRow = model.factors[i];
                double[] sigmaRow = sigma.factors[i];
                for (int f = 0; f < numFactors; f++) {
                    buffer[i * numFactors + f] = modelRow[f] * w;
                    buffer[(vectorSize + i) * numFactors + f] = sigmaRow[f] * w;
                }
                if (dim[1] > 0) {
                    buffer[linearOffset + i] = modelRow[numFactors] * w;
                    buffer[linearSigmaOffset + i] = sigmaRow[numFactors] * w;
                }
                buffer[weightOffset + i] = w;
            }
            if (dim[0] > 0) {
                int biasOffset = vectorSize * (2 * rowWidth + 1);
                buffer[biasOffset] = model.bias;
                buffer[biasOffset + 1] = sigma.bias;
            }
        }

        /**
         * Draws {@code 2 * batchSize} random samples and applies one AdaGrad-style step per sample.
         * Each step does {@code acc += g * g; param -= learnRate * g / sqrt(acc + 1e-8)}. Sample
         * weights ({@code f0}) are accumulated per touched feature into {@code featureWeights}.
         */
        private void updateFactors(List<Tuple3<Double, Double, Vector>> samples,
                                   BaseFmTrainBatchOp.FmDataFormat model,
                                   double learnRate,
                                   BaseFmTrainBatchOp.FmDataFormat sigma,
                                   double[] featureWeights) {
            int numFactors = dim[2];
            for (int iter = 0; iter < batchSize * 2; iter++) {
                Tuple3<Double, Double, Vector> sample = samples.get(rand.nextInt(samples.size()));
                Tuple2<Double, double[]> prediction = FmOptimizer.calcY(sample.f2, model, dim);
                double[] vx = prediction.f1;
                double dldy = lossFunc.dldy(sample.f1, prediction.f0);
                int[] indices = featureIndices(sample.f2);
                double[] values = featureValues(sample.f2);

                if (dim[0] > 0) {
                    double grad = dldy + lambda[0] * model.bias;
                    sigma.bias += grad * grad;
                    model.bias += (-learnRate * grad) / Math.sqrt(sigma.bias + 1.0e-8);
                }
                for (int j = 0; j < indices.length; j++) {
                    int featureId = indices[j];
                    double x = values[j];
                    featureWeights[featureId] += sample.f0;
                    double[] modelRow = model.factors[featureId];
                    double[] sigmaRow = sigma.factors[featureId];
                    for (int f = 0; f < numFactors; f++) {
                        double grad = dldy * x * (vx[f] - x * modelRow[f]) + lambda[2] * modelRow[f];
                        sigmaRow[f] += grad * grad;
                        modelRow[f] += (-learnRate * grad) / Math.sqrt(sigmaRow[f] + 1.0e-8);
                    }
                    if (dim[1] > 0) {
                        double grad = dldy * x + lambda[1] * modelRow[numFactors];
                        sigmaRow[numFactors] += grad * grad;
                        modelRow[numFactors] += ((-grad) * learnRate) / Math.sqrt(sigmaRow[numFactors] + 1.0e-8);
                    }
                }
            }
        }
    }

    /**
     * Creates an optimizer over the given training samples.
     *
     * @param dataSet samples as {@code (weight, label, features)}
     * @param params training parameters: intercept/linear flags, number of factors, lambdas,
     *     learning rate, epochs, epsilon, mini-batch size and task
     */
    public FmOptimizer(DataSet<Tuple3<Double, Double, Vector>> dataSet, Params params) {
        this.params = params;
        this.trainData = dataSet;
        this.dim[0] = (Boolean) params.get(FmTrainParams.WITH_INTERCEPT) ? 1 : 0;
        this.dim[1] = (Boolean) params.get(FmTrainParams.WITH_LINEAR_ITEM) ? 1 : 0;
        this.dim[2] = (Integer) params.get(FmTrainParams.NUM_FACTOR);
        this.lambda = new double[3];
        this.lambda[0] = (Double) params.get(FmTrainParams.LAMBDA_0);
        this.lambda[1] = (Double) params.get(FmTrainParams.LAMBDA_1);
        this.lambda[2] = (Double) params.get(FmTrainParams.LAMBDA_2);
    }

    /** Sets the initial model that is broadcast to every worker; required before {@link #optimize()}. */
    public void setWithInitFactors(DataSet<BaseFmTrainBatchOp.FmDataFormat> dataSet) {
        this.fmModel = dataSet;
    }

    /**
     * Builds the iterative training job.
     *
     * @return a data set with one tuple (emitted by subtask 0 of the parsing step): the trained
     *     model and the per-step convergence triples described on {@link FmIterTermination}
     * @throws IllegalStateException if {@link #setWithInitFactors} was not called
     */
    public DataSet<Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]>> optimize() {
        if (fmModel == null) {
            throw new IllegalStateException(
                "Initial FM model is not set; call setWithInitFactors() before optimize().");
        }
        return new IterativeComQueue()
            .initWithPartitionedData(OptimVariable.fmTrainData, trainData)
            .initWithBroadcastData(OptimVariable.fmModel, fmModel)
            .add(new UpdateLocalModel(dim, lambda, params))
            .add(new AllReduce(OptimVariable.factorAllReduce))
            .add(new UpdateGlobalModel(dim))
            .add(new CalcLossAndEvaluation(dim, (BaseFmTrainBatchOp.Task) params.get(ModelParamName.TASK)))
            .add(new AllReduce(OptimVariable.lossAucAllReduce))
            .setCompareCriterionOfNode0((CompareCriterionFunction) new FmIterTermination(params))
            .closeWith(new OutputFmModel())
            .setMaxIter(Integer.MAX_VALUE)
            .exec()
            .mapPartition(new ParseRowModel());
    }

    /**
     * Computes the raw FM score of one sample:
     * {@code bias + sum_j x_j * w_j + 0.5 * sum_f ((sum_j v_jf x_j)^2 - sum_j (v_jf x_j)^2)}.
     * The bias term applies when {@code dim[0] > 0}; the linear term applies when {@code dim[1] > 0}.
     *
     * @param vector features; must be a {@link SparseVector} or a {@link DenseVector}
     * @param model current model; factor row {@code j} holds {@code dim[2]} factors, then the
     *     linear weight
     * @param dim {@code {intercept ? 1 : 0, linear ? 1 : 0, numFactors}}
     * @return the score, and the per-factor sums {@code sum_j v_jf x_j} (reused for gradients)
     */
    public static Tuple2<Double, double[]> calcY(Vector vector, BaseFmTrainBatchOp.FmDataFormat model, int[] dim) {
        int[] indices = featureIndices(vector);
        double[] values = featureValues(vector);
        int numFactors = dim[2];
        double[] vx = new double[numFactors];
        double[] vxSquared = new double[numFactors];
        double y = dim[0] > 0 ? model.bias : 0.0;
        for (int j = 0; j < indices.length; j++) {
            double[] row = model.factors[indices[j]];
            double x = values[j];
            if (dim[1] > 0) {
                y += x * row[numFactors];
            }
            for (int f = 0; f < numFactors; f++) {
                double term = x * row[f];
                vx[f] += term;
                vxSquared[f] += term * term;
            }
        }
        for (int f = 0; f < numFactors; f++) {
            y += 0.5 * (vx[f] * vx[f] - vxSquared[f]);
        }
        return Tuple2.of(y, vx);
    }

    /** Creates the loss shared by training and evaluation: logistic unless the task is regression. */
    private static BaseFmTrainBatchOp.LossFunction createLossFunction(BaseFmTrainBatchOp.Task task) {
        if (!task.equals(BaseFmTrainBatchOp.Task.REGRESSION)) {
            return new BaseFmTrainBatchOp.LogitLoss();
        }
        // Target bounds of +/-1e20, widened by 20% of their range on each side.
        // NOTE(review): presumably SquareLoss uses these as clipping bounds; confirm in BaseFmTrainBatchOp.
        double maxTarget = 1.0e20;
        double minTarget = -1.0e20;
        double margin = Math.max(maxTarget - minTarget, 1.0) * 0.2;
        return new BaseFmTrainBatchOp.SquareLoss(maxTarget + margin, minTarget - margin);
    }

    /** Returns the feature indices: the stored indices for sparse vectors, {@code 0..size-1} for dense. */
    private static int[] featureIndices(Vector vector) {
        if (vector instanceof SparseVector) {
            return ((SparseVector) vector).getIndices();
        }
        int[] indices = new int[vector.size()];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        return indices;
    }

    /** Returns the feature values aligned with {@link #featureIndices(Vector)}. */
    private static double[] featureValues(Vector vector) {
        if (vector instanceof SparseVector) {
            return ((SparseVector) vector).getValues();
        }
        return ((DenseVector) vector).getData();
    }
}
