package com.alibaba.alink.operator.stream.onlinelearning.kernel;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.fm.FmModelData;
import com.alibaba.alink.operator.common.fm.FmModelDataConverter;
import com.alibaba.alink.operator.common.optim.FmOptimizer;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.onlinelearning.OnlineLearningTrainParams;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/kernel/FmOnlineLearningKernel.class */
public class FmOnlineLearningKernel extends OnlineLearningKernel {
    /** FM model being trained; replaced by deserializeModel, written out by serializeModel, mutated in place by updateModel. */
    public FmModelData modelData;
    /**
     * Per-parameter optimizer state, allocated in deserializeModel for every optimizer except SGD.
     * Its meaning depends on the optimizer branch that reads it in updateModel/updateModelVal.
     */
    private BaseFmTrainBatchOp.FmDataFormat nParam;
    /** Second per-parameter optimizer state, allocated in deserializeModel only for ADAM and FTRL. */
    private BaseFmTrainBatchOp.FmDataFormat zParam;
    /** Loss whose dldy drives calcGradient: LogitLoss or SquareLoss, chosen by the constructor flag. */
    private final transient BaseFmTrainBatchOp.LossFunction lossFunc;
    /** Copied from modelData.regular; calcGradient applies [0] to the bias, [1] to the linear weight, [2] to the factor weights. */
    double[] regular;
    /**
     * Copied from modelData.dim; dim[0] > 0 enables the bias gradient, dim[1] > 0 enables the linear
     * term, and dim[2] is the number of factor components iterated per feature.
     */
    int[] dim;
    /** Decompiler-exposed assertion switch; assigned once in the static initializer of this class. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.alibaba.alink.operator.stream.onlinelearning.kernel.FmOnlineLearningKernel$1, reason: invalid class name */
    /* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/kernel/FmOnlineLearningKernel$1.class */
    /**
     * Decompiled switch-map holder: maps each OptimMethod ordinal to a small case number
     * (FTRL=1, ADAGRAD=2, RMSprop=3, ADAM=4, SGD=5, MOMENTUM=6). The switches in updateModel and
     * updateModelVal index this table with optimMethod.ordinal(); ordinals that are not registered
     * here keep the array default of 0 and therefore match no case.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per enum constant, indexed by ordinal(); filled by the static initializer below.
        static final /* synthetic */ int[] $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod = new int[OnlineLearningTrainParams.OptimMethod.values().length];

        static {
            // Each constant is registered in its own try block so that a constant missing at
            // runtime (NoSuchFieldError) only leaves its own slot at 0 instead of aborting the rest.
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.FTRL.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Deliberately ignored: slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.ADAGRAD.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Deliberately ignored: slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.RMSprop.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Deliberately ignored: slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.ADAM.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // Deliberately ignored: slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.SGD.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
                // Deliberately ignored: slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.MOMENTUM.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
                // Deliberately ignored: slot stays 0.
            }
        }
    }

    public FmOnlineLearningKernel(Params params, boolean z) {
        super(params);
        if (z) {
            this.lossFunc = new BaseFmTrainBatchOp.LogitLoss();
        } else {
            double max = Math.max((-1.0E20d) - 1.0E20d, 1.0d);
            this.lossFunc = new BaseFmTrainBatchOp.SquareLoss((-1.0E20d) + (max * 0.2d), 1.0E20d - (max * 0.2d));
        }
    }

    /**
     * Locates the vector column of the loaded model in the given schema.
     *
     * @param tableSchema schema of the incoming data
     * @return {@code -1} when the model declares feature column names; otherwise the result of
     *         {@code TableUtil.findColIndexWithAssertAndHint} for the model's vector column name
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public int getVectorIdx(TableSchema tableSchema) {
        final boolean usesFeatureColumns = this.modelData.featureColNames != null;
        return usesFeatureColumns
            ? -1
            : TableUtil.findColIndexWithAssertAndHint(tableSchema, this.modelData.vectorColName);
    }

    /**
     * Locates the label column of the loaded model in the given schema.
     *
     * @param tableSchema schema of the incoming data
     * @return the result of {@code TableUtil.findColIndexWithAssertAndHint} for the model's label column name
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public int getLabelIdx(TableSchema tableSchema) {
        final String labelColumn = this.modelData.labelColName;
        return TableUtil.findColIndexWithAssertAndHint(tableSchema, labelColumn);
    }

    /**
     * Resolves the model's feature columns against the given schema.
     *
     * @param tableSchema schema of the incoming data
     * @return the indices reported by {@code TableUtil.findColIndices}, or {@code null} when the
     *         model declares no feature column names
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public int[] getFeatureIndices(TableSchema tableSchema) {
        if (this.modelData.featureColNames == null) {
            return null;
        }
        return TableUtil.findColIndices(tableSchema, this.modelData.featureColNames);
    }

    /**
     * Exposes the gradient accumulator filled by {@code calcGradient}.
     *
     * @return the live {@code sparseGradient} map (not a copy); key {@code -1} holds the bias entry,
     *         other keys are feature indices
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public Map<Integer, double[]> getGradient() {
        return sparseGradient;
    }

    /**
     * Adds the gradient contribution of one sample to {@code sparseGradient}.
     *
     * <p>Layout of the accumulator: key {@code -1} maps to {@code {gradientSum, sampleCount}} for the
     * bias (only when {@code dim[0] > 0}). Every touched feature index maps to an array of length
     * {@code dim[2] + dim[1] + 1}: slots {@code [0, dim[2])} hold the factor gradients, slot
     * {@code dim[2]} the linear gradient (only when {@code dim[1] > 0}), and the last slot the number
     * of samples accumulated for that feature.
     *
     * @param vector feature vector of the sample; a {@code SparseVector} is read through its
     *               indices/values, anything else is cast to {@code DenseVector}
     * @param obj    label of the sample; the target is 1.0 when it equals {@code labelValues[0]},
     *               otherwise {@code Criteria.INVALID_GAIN}
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public void calcGradient(Vector vector, Object obj) throws Exception {
        final double target = obj.equals(this.modelData.labelValues[0]) ? 1.0d : Criteria.INVALID_GAIN;
        final Tuple2<Double, double[]> prediction = FmOptimizer.calcY(vector, this.modelData.fmModel, this.dim);
        final double lossGrad = this.lossFunc.dldy(target, prediction.f0.doubleValue());
        final double[] aux = prediction.f1;

        // Normalize both vector flavours to parallel (index, value) arrays.
        final int[] indices;
        final double[] values;
        if (vector instanceof SparseVector) {
            final SparseVector sparse = (SparseVector) vector;
            indices = sparse.getIndices();
            values = sparse.getValues();
        } else {
            final int size = vector.size();
            indices = new int[size];
            for (int pos = 0; pos < size; pos++) {
                indices[pos] = pos;
            }
            values = ((DenseVector) vector).getData();
        }

        // Bias entry lives under key -1 as {sum, count}.
        if (this.dim[0] > 0) {
            final double biasGrad = lossGrad + (this.regular[0] * this.modelData.fmModel.bias);
            if (this.sparseGradient.containsKey(-1)) {
                final double[] biasAcc = this.sparseGradient.get(-1);
                biasAcc[1] += 1.0d;
                biasAcc[0] += biasGrad;
            } else {
                this.sparseGradient.put(-1, new double[]{biasGrad, 1.0d});
            }
        }

        final double[][] factors = this.modelData.fmModel.factors;
        final int numFactors = this.dim[2];
        final boolean hasLinear = this.dim[1] > 0;
        for (int k = 0; k < indices.length; k++) {
            final int feature = indices[k];
            final Integer key = Integer.valueOf(feature);
            if (this.sparseGradient.containsKey(key)) {
                // Feature already seen in this batch: accumulate and bump its sample count.
                final double[] acc = this.sparseGradient.get(key);
                acc[acc.length - 1] += 1.0d;
                if (hasLinear) {
                    acc[numFactors] = acc[numFactors] + (lossGrad * values[k]) + (this.regular[1] * factors[feature][numFactors]);
                }
                for (int f = 0; f < numFactors; f++) {
                    acc[f] = acc[f] + (lossGrad * values[k] * (aux[f] - (values[k] * factors[feature][f]))) + (this.regular[2] * factors[feature][f]);
                }
            } else {
                // First occurrence: create the entry with a sample count of 1.
                final double[] fresh = new double[numFactors + this.dim[1] + 1];
                if (hasLinear) {
                    fresh[numFactors] = (lossGrad * values[k]) + (this.regular[1] * factors[feature][numFactors]);
                }
                for (int f = 0; f < numFactors; f++) {
                    fresh[f] = (lossGrad * values[k] * (aux[f] - (values[k] * factors[feature][f]))) + (this.regular[2] * factors[feature][f]);
                }
                fresh[numFactors + this.dim[1]] = 1.0d;
                this.sparseGradient.put(key, fresh);
            }
        }
    }

    /**
     * Applies one optimizer step to the model from an accumulated gradient map.
     *
     * <p>Each entry is first averaged by the sample count stored in its last slot. Key {@code -1}
     * ({@code {sum, count}}) updates the bias; every other key updates the linear weight (when
     * {@code dim[1] > 0}) and the {@code dim[2]} factor weights of that feature via
     * {@code updateModelVal}. For ADAM the running powers of beta1/beta2 are advanced once per call.
     *
     * @param obj a {@code Map<Integer, double[]>} in the layout described above
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public void updateModel(Object obj) {
        @SuppressWarnings("unchecked")
        final Map<Integer, double[]> gradients = (Map<Integer, double[]>) obj;
        if (this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.ADAM)) {
            this.beta1Power *= this.beta1;
            this.beta2Power *= this.beta2;
        }
        for (Map.Entry<Integer, double[]> entry : gradients.entrySet()) {
            final int key = entry.getKey().intValue();
            final double[] acc = entry.getValue();
            if (key == -1) {
                if (!$assertionsDisabled && acc.length != 2) {
                    throw new AssertionError();
                }
                stepBias(acc[0] / acc[1]);
                continue;
            }
            // Average every gradient slot by the sample count kept in the last slot.
            final double[] mean = new double[acc.length - 1];
            for (int slot = 0; slot < mean.length; slot++) {
                mean[slot] = acc[slot] / acc[mean.length];
            }
            if (this.dim[1] > 0) {
                updateModelVal(key, this.dim[2], mean[this.dim[2]]);
            }
            final int numFactors = this.dim[2];
            for (int f = 0; f < numFactors; f++) {
                updateModelVal(key, f, mean[f]);
            }
        }
    }

    /**
     * Applies one optimizer step to the bias term using the averaged bias gradient.
     * Optimizers without a case below leave the bias unchanged.
     */
    private void stepBias(double grad) {
        switch (this.optimMethod) {
            case FTRL: {
                final double sigma = (Math.sqrt(this.nParam.bias + (grad * grad)) - Math.sqrt(this.nParam.bias)) / this.alpha;
                this.zParam.bias += grad - (sigma * this.modelData.fmModel.bias);
                this.nParam.bias += grad * grad;
                // NOTE(review): Criteria.INVALID_GAIN appears to stand in for 0.0 here — confirm.
                if (Math.abs(this.zParam.bias) <= this.l1) {
                    this.modelData.fmModel.bias = Criteria.INVALID_GAIN;
                } else {
                    this.modelData.fmModel.bias = (((this.zParam.bias < Criteria.INVALID_GAIN ? -1 : 1) * this.l1) - this.zParam.bias) / (((this.beta + Math.sqrt(this.nParam.bias)) / this.alpha) + this.l2);
                }
                break;
            }
            case ADAGRAD:
                this.nParam.bias += grad * grad;
                this.modelData.fmModel.bias -= (this.learningRate * grad) / Math.sqrt(this.nParam.bias + 1.0E-8d);
                break;
            case RMSprop:
                this.nParam.bias = (this.gamma * this.nParam.bias) + ((1.0d - this.gamma) * grad * grad);
                this.modelData.fmModel.bias -= (this.learningRate * grad) / Math.sqrt(this.nParam.bias + 1.0E-8d);
                break;
            case ADAM: {
                this.nParam.bias = (this.beta1 * this.nParam.bias) + ((1.0d - this.beta1) * grad);
                this.zParam.bias = (this.beta2 * this.zParam.bias) + ((1.0d - this.beta2) * grad * grad);
                final double firstMoment = this.nParam.bias / (1.0d - this.beta1Power);
                final double secondMoment = this.zParam.bias / (1.0d - this.beta2Power);
                this.modelData.fmModel.bias -= (this.learningRate * firstMoment) / (Math.sqrt(secondMoment) + 1.0E-8d);
                break;
            }
            case SGD:
                this.modelData.fmModel.bias -= this.learningRate * grad;
                break;
            case MOMENTUM:
                this.nParam.bias = (this.gamma * this.nParam.bias) + (this.learningRate * grad);
                this.modelData.fmModel.bias -= this.nParam.bias;
                break;
            default:
                break;
        }
    }

    /**
     * Applies one optimizer step to a single entry of the factor matrix.
     *
     * <p>Fix over the previous version: the FTRL branch used to leave the switch and run the
     * trailing AdaGrad code as well, which added {@code grad * grad} to {@code nParam} a second
     * time and applied an extra AdaGrad step on top of the FTRL weight (undoing the L1 zeroing).
     * The bias update in {@code updateModel} never did this; FTRL now returns after its own update.
     *
     * @param featureIdx row of the factor matrix (feature index)
     * @param col        column within that row; callers pass {@code dim[2]} for the linear weight
     *                   and {@code 0 .. dim[2]-1} for the factor weights
     * @param grad       averaged gradient for that entry
     */
    private void updateModelVal(int featureIdx, int col, double grad) {
        switch (this.optimMethod) {
            case FTRL: {
                final double[] n = this.nParam.factors[featureIdx];
                final double[] z = this.zParam.factors[featureIdx];
                final double[] w = this.modelData.fmModel.factors[featureIdx];
                final double sigma = (Math.sqrt(n[col] + (grad * grad)) - Math.sqrt(n[col])) / this.alpha;
                z[col] = z[col] + (grad - (sigma * w[col]));
                n[col] = n[col] + (grad * grad);
                if (Math.abs(z[col]) > this.l1) {
                    // NOTE(review): Criteria.INVALID_GAIN appears to stand in for 0.0 here — confirm.
                    w[col] = (((z[col] < Criteria.INVALID_GAIN ? -1 : 1) * this.l1) - z[col]) / (((this.beta + Math.sqrt(n[col])) / this.alpha) + this.l2);
                } else {
                    w[col] = 0.0d;
                }
                // Must not fall through into the AdaGrad step below.
                return;
            }
            case ADAGRAD: {
                final double[] n = this.nParam.factors[featureIdx];
                final double[] w = this.modelData.fmModel.factors[featureIdx];
                n[col] = n[col] + (grad * grad);
                w[col] = w[col] - ((this.learningRate * grad) / Math.sqrt(n[col] + 1.0E-8d));
                return;
            }
            case RMSprop: {
                final double[] n = this.nParam.factors[featureIdx];
                final double[] w = this.modelData.fmModel.factors[featureIdx];
                n[col] = (this.gamma * n[col]) + ((1.0d - this.gamma) * grad * grad);
                w[col] = w[col] - ((this.learningRate * grad) / Math.sqrt(n[col] + 1.0E-8d));
                return;
            }
            case ADAM: {
                final double[] n = this.nParam.factors[featureIdx];
                final double[] z = this.zParam.factors[featureIdx];
                final double[] w = this.modelData.fmModel.factors[featureIdx];
                n[col] = (this.beta1 * n[col]) + ((1.0d - this.beta1) * grad);
                z[col] = (this.beta2 * z[col]) + ((1.0d - this.beta2) * grad * grad);
                // Bias-corrected moments; beta1Power/beta2Power are advanced once per updateModel call.
                final double firstMoment = n[col] / (1.0d - this.beta1Power);
                final double secondMoment = z[col] / (1.0d - this.beta2Power);
                w[col] = w[col] - ((this.learningRate * firstMoment) / (Math.sqrt(secondMoment) + 1.0E-8d));
                return;
            }
            case SGD: {
                final double[] w = this.modelData.fmModel.factors[featureIdx];
                w[col] = w[col] - (this.learningRate * grad);
                return;
            }
            case MOMENTUM: {
                final double[] n = this.nParam.factors[featureIdx];
                final double[] w = this.modelData.fmModel.factors[featureIdx];
                n[col] = (this.gamma * n[col]) + (this.learningRate * grad);
                w[col] = w[col] - n[col];
                return;
            }
            default:
                // Unknown optimizer: leave the model untouched, as before.
                return;
        }
    }

    /**
     * Serializes the current model into rows.
     *
     * @return the rows that {@code FmModelDataConverter.save2} emitted into a {@code RowCollector}
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public List<Row> serializeModel() {
        final RowCollector sink = new RowCollector();
        final FmModelDataConverter converter = new FmModelDataConverter();
        converter.save2(this.modelData, (Collector<Row>) sink);
        return sink.getRows();
    }

    /**
     * Loads the model from rows and (re)initializes the optimizer state that matches
     * {@code optimMethod}: {@code nParam} for every optimizer except SGD, {@code zParam} only for
     * ADAM and FTRL. Both are sized like the loaded factor matrix. Finally {@code dim} and
     * {@code regular} are taken from the loaded model.
     *
     * @param list serialized model rows understood by {@code FmModelDataConverter}
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public void deserializeModel(List<Row> list) {
        final FmModelDataConverter converter = new FmModelDataConverter();
        this.modelData = converter.load(list);
        final double[][] factors = this.modelData.fmModel.factors;

        final boolean isSgd = this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.SGD);
        if (!isSgd) {
            this.nParam = new BaseFmTrainBatchOp.FmDataFormat(factors.length, factors[0].length, this.modelData.dim, Criteria.INVALID_GAIN);
        }

        final boolean needsSecondState = this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.ADAM)
            || this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.FTRL);
        if (needsSecondState) {
            this.zParam = new BaseFmTrainBatchOp.FmDataFormat(factors.length, factors[0].length, this.modelData.dim, Criteria.INVALID_GAIN);
        }

        this.dim = this.modelData.dim;
        this.regular = this.modelData.regular;
    }

    // Decompiled form of the compiler's assertion support: records once, at class initialization,
    // whether assertions are disabled for this class. updateModel consults the flag before its
    // length check on the bias gradient entry.
    static {
        $assertionsDisabled = !FmOnlineLearningKernel.class.desiredAssertionStatus();
    }
}
