package com.alibaba.alink.operator.stream.onlinelearning.kernel;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.SmoothHingeLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.SquareLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.UnaryLossFunc;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.onlinelearning.OnlineLearningTrainParams;
import java.util.List;
import java.util.Map;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/kernel/LinearOnlineLearningKernel.class */
public class LinearOnlineLearningKernel extends OnlineLearningKernel {
    /** Linear model being trained; assigned by {@code deserializeModel} and written out by {@code serializeModel}. */
    protected LinearModelData modelData;
    /**
     * Per-coefficient optimizer state, sized to the coefficient vector. Allocated in
     * {@code deserializeModel} for every optimizer except SGD and read/written by {@code updateModelVal}.
     */
    protected double[] nParam;
    /**
     * Second per-coefficient optimizer state array. Allocated in {@code deserializeModel} only when the
     * optimizer is ADAM or FTRL.
     */
    protected double[] zParam;
    /** Loss function whose derivative is used by {@code calcGradient}; selected once in the constructor. */
    private final UnaryLossFunc lossFunc;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.alibaba.alink.operator.stream.onlinelearning.kernel.LinearOnlineLearningKernel$1, reason: invalid class name */
    /* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/kernel/LinearOnlineLearningKernel$1.class */
    /**
     * Switch-map holder recovered by the decompiler (originally the compiler-generated {@code $1} class).
     *
     * <p>Each array maps an enum constant's {@code ordinal()} to a small positive case number; a slot left
     * at 0 means "no case for this constant". Both tables are declared here so the class is
     * self-contained: the {@code LinearModelType} table was previously assigned without ever being
     * declared.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod = new int[OnlineLearningTrainParams.OptimMethod.values().length];
        static final /* synthetic */ int[] $SwitchMap$com$alibaba$alink$operator$common$linear$LinearModelType = new int[LinearModelType.values().length];

        static {
            // Every lookup is wrapped individually: a NoSuchFieldError only means that constant is
            // absent from the enum version loaded at runtime, so its slot simply stays 0.
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.FTRL.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.ADAGRAD.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.RMSprop.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.ADAM.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.SGD.ordinal()] = 5;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
            try {
                $SwitchMap$com$alibaba$alink$params$onlinelearning$OnlineLearningTrainParams$OptimMethod[OnlineLearningTrainParams.OptimMethod.MOMENTUM.ordinal()] = 6;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$linear$LinearModelType[LinearModelType.LinearReg.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$linear$LinearModelType[LinearModelType.SVM.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // constant missing at runtime; leave slot unmapped
            }
        }
    }

    public LinearOnlineLearningKernel(Params params, LinearModelType linearModelType) {
        super(params);
        switch (linearModelType) {
            case LinearReg:
                this.lossFunc = new SquareLossFunc();
                return;
            case SVM:
                this.lossFunc = new SmoothHingeLossFunc();
                return;
            default:
                this.lossFunc = new LogLossFunc();
                return;
        }
    }

    /**
     * Locates the vector column in the input schema.
     *
     * @param tableSchema schema of the incoming training data
     * @return the index of the model's vector column, or {@code -1} when the model carries
     *         feature names (i.e. no vector column is used)
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public int getVectorIdx(TableSchema tableSchema) {
        if (modelData.featureNames == null) {
            // No named features: the sample is carried in the vector column.
            return TableUtil.findColIndexWithAssertAndHint(tableSchema, modelData.vectorColName);
        }
        return -1;
    }

    /**
     * Locates the label column in the input schema.
     *
     * @param tableSchema schema of the incoming training data
     * @return the index of the column named by the model's {@code labelName}
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public int getLabelIdx(TableSchema tableSchema) {
        return TableUtil.findColIndexWithAssertAndHint(tableSchema, modelData.labelName);
    }

    /**
     * Resolves the model's feature columns against the input schema.
     *
     * @param tableSchema schema of the incoming training data
     * @return indices of the feature columns, or {@code null} when the model has no feature names
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public int[] getFeatureIndices(TableSchema tableSchema) {
        return modelData.featureNames == null
            ? null
            : TableUtil.findColIndices(tableSchema, modelData.featureNames);
    }

    /**
     * Exposes the gradient accumulated so far.
     *
     * @return the live accumulator map (not a copy); each value is a two-element array holding
     *         the gradient sum and the number of contributions for that coefficient index
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public Map<Integer, double[]> getGradient() {
        return sparseGradient;
    }

    /**
     * Adds one sample's gradient contribution to the accumulator returned by {@code getGradient()}.
     *
     * <p>When the model has exactly two label values the label is encoded as {@code +1} if it equals
     * the first label value and {@code -1} otherwise; in every other case the label is read as a
     * number. If the model has an intercept, the sample is prefixed with a constant {@code 1.0}.
     * For each touched coefficient index the accumulator holds {@code {sum, count}}, where the sum
     * collects {@code lossDerivative * featureValue}.
     *
     * <p>Fix: the dense path used to seed a new accumulator entry with the raw feature value,
     * dropping the loss-derivative factor; both paths now share one accumulation helper.
     *
     * @param vector feature vector of the sample (dense or sparse)
     * @param obj    label of the sample
     * @throws Exception declared by the base kernel contract
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public void calcGradient(Vector vector, Object obj) throws Exception {
        // NOTE(review): labelValues is presumably absent for regression models; the null guard keeps
        // that path on the numeric-label branch instead of failing - confirm against the converter.
        final double target;
        if (this.modelData.labelValues != null && this.modelData.labelValues.length == 2) {
            target = obj.equals(this.modelData.labelValues[0]) ? 1.0d : -1.0d;
        } else {
            target = ((Number) obj).doubleValue();
        }
        if (this.modelData.hasInterceptItem) {
            vector = vector.prefix(1.0d);
        }
        final double derivative = this.lossFunc.derivative(this.modelData.coefVector.dot(vector), target);

        if (vector instanceof DenseVector) {
            final double[] values = ((DenseVector) vector).getData();
            final int coefCount = this.modelData.coefVector.size();
            for (int idx = 0; idx < coefCount; idx++) {
                addToSparseGradient(idx, derivative * values[idx]);
            }
            return;
        }

        final SparseVector sparseVector = (SparseVector) vector;
        final int[] indices = sparseVector.getIndices();
        final double[] values = sparseVector.getValues();
        for (int pos = 0; pos < indices.length; pos++) {
            addToSparseGradient(indices[pos], derivative * values[pos]);
        }
    }

    /** Adds {@code contribution} to the gradient sum of {@code index} and bumps its sample count. */
    private void addToSparseGradient(int index, double contribution) {
        final Integer key = Integer.valueOf(index);
        final double[] sumAndCount = this.sparseGradient.get(key);
        if (sumAndCount == null) {
            this.sparseGradient.put(key, new double[]{contribution, 1.0d});
        } else {
            sumAndCount[0] += contribution;
            sumAndCount[1] += 1.0d;
        }
    }

    /**
     * Applies an accumulated gradient to the model.
     *
     * <p>The argument is cast to a map from coefficient index to a {@code {sum, count}} pair; each
     * coefficient is updated with the mean gradient {@code sum / count}. For ADAM the running
     * powers of {@code beta1} and {@code beta2} are advanced once per call, before any coefficient
     * is touched.
     *
     * @param obj gradient map as described above
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public void updateModel(Object obj) {
        Map<?, ?> accumulated = (Map<?, ?>) obj;
        int entryCount = accumulated.size();
        int[] coefIndices = new int[entryCount];
        double[] meanGradients = new double[entryCount];

        // Phase 1: flatten the map into parallel arrays of index / averaged gradient.
        int cursor = 0;
        for (Map.Entry<?, ?> entry : accumulated.entrySet()) {
            coefIndices[cursor] = ((Integer) entry.getKey()).intValue();
            double[] sumAndCount = (double[]) entry.getValue();
            meanGradients[cursor] = sumAndCount[0] / sumAndCount[1];
            cursor++;
        }

        // Phase 2: advance the ADAM bias-correction powers once for this step.
        if (this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.ADAM)) {
            this.beta1Power *= this.beta1;
            this.beta2Power *= this.beta2;
        }

        // Phase 3: per-coefficient update.
        for (int k = 0; k < coefIndices.length; k++) {
            updateModelVal(coefIndices[k], meanGradients[k]);
        }
    }

    /**
     * Updates a single coefficient with the configured optimizer.
     *
     * <p>The dispatch is a direct switch on the optimizer enum and uses plain numeric literals, so the
     * method no longer depends on the ordinal switch map or on unrelated constants that merely
     * happened to carry the needed values.
     *
     * <ul>
     *   <li>FTRL: updates {@code zParam}/{@code nParam}, then sets the coefficient to 0 when
     *       {@code |z| <= l1}, otherwise to the closed-form proximal value.</li>
     *   <li>ADAGRAD: {@code nParam} accumulates the squared gradient.</li>
     *   <li>RMSprop: {@code nParam} is a {@code gamma}-weighted moving average of the squared gradient.</li>
     *   <li>ADAM: {@code nParam}/{@code zParam} are the first/second moment averages, corrected by
     *       {@code beta1Power}/{@code beta2Power}.</li>
     *   <li>SGD: plain {@code learningRate * gradient} step.</li>
     *   <li>MOMENTUM: {@code nParam} holds the velocity.</li>
     * </ul>
     * Any other optimizer value leaves the model untouched.
     *
     * @param i index of the coefficient to update
     * @param d (averaged) gradient for that coefficient
     */
    private void updateModelVal(int i, double d) {
        final double epsilon = 1.0E-8d;
        switch (this.optimMethod) {
            case FTRL: {
                double sigma = (Math.sqrt(this.nParam[i] + (d * d)) - Math.sqrt(this.nParam[i])) / this.alpha;
                this.zParam[i] += d - (sigma * this.modelData.coefVector.get(i));
                this.nParam[i] += d * d;
                double z = this.zParam[i];
                if (Math.abs(z) <= this.l1) {
                    // L1 threshold not exceeded: the coefficient is clipped to exactly zero.
                    this.modelData.coefVector.set(i, 0.0d);
                } else {
                    double sign = z < 0.0d ? -1.0d : 1.0d;
                    double denominator = ((this.beta + Math.sqrt(this.nParam[i])) / this.alpha) + this.l2;
                    this.modelData.coefVector.set(i, ((sign * this.l1) - z) / denominator);
                }
                return;
            }
            case ADAGRAD:
                this.nParam[i] += d * d;
                this.modelData.coefVector.set(i, this.modelData.coefVector.get(i) - ((this.learningRate * d) / Math.sqrt(this.nParam[i] + epsilon)));
                return;
            case RMSprop:
                this.nParam[i] = (this.gamma * this.nParam[i]) + ((1.0d - this.gamma) * d * d);
                this.modelData.coefVector.set(i, this.modelData.coefVector.get(i) - ((this.learningRate * d) / Math.sqrt(this.nParam[i] + epsilon)));
                return;
            case ADAM: {
                this.nParam[i] = (this.beta1 * this.nParam[i]) + ((1.0d - this.beta1) * d);
                this.zParam[i] = (this.beta2 * this.zParam[i]) + ((1.0d - this.beta2) * d * d);
                double correctedMean = this.nParam[i] / (1.0d - this.beta1Power);
                double correctedVariance = this.zParam[i] / (1.0d - this.beta2Power);
                this.modelData.coefVector.set(i, this.modelData.coefVector.get(i) - ((this.learningRate * correctedMean) / (Math.sqrt(correctedVariance) + epsilon)));
                return;
            }
            case SGD:
                this.modelData.coefVector.set(i, this.modelData.coefVector.get(i) - (this.learningRate * d));
                return;
            case MOMENTUM:
                this.nParam[i] = (this.gamma * this.nParam[i]) + (this.learningRate * d);
                this.modelData.coefVector.set(i, this.modelData.coefVector.get(i) - this.nParam[i]);
                return;
            default:
                return;
        }
    }

    /**
     * Serializes the current model into rows.
     *
     * @return the rows produced by {@code LinearModelDataConverter.save} for the current model data
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public List<Row> serializeModel() {
        RowCollector collectedRows = new RowCollector();
        LinearModelDataConverter converter = new LinearModelDataConverter();
        converter.save(modelData, collectedRows);
        return collectedRows.getRows();
    }

    /**
     * Loads the model from rows and (re)allocates the optimizer state arrays.
     *
     * <p>{@code nParam} is allocated for every optimizer except SGD; {@code zParam} is allocated
     * only for ADAM and FTRL. Both are zero-filled and sized to the coefficient vector. Arrays
     * that are not needed by the current optimizer are left as they were.
     *
     * @param list serialized model rows
     */
    @Override // com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel
    public void deserializeModel(List<Row> list) {
        LinearModelDataConverter converter = new LinearModelDataConverter();
        this.modelData = converter.load(list);

        boolean isPlainSgd = this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.SGD);
        if (!isPlainSgd) {
            this.nParam = new double[this.modelData.coefVector.size()];
        }

        boolean needsSecondState = this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.ADAM)
            || this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.FTRL);
        if (needsSecondState) {
            this.zParam = new double[this.modelData.coefVector.size()];
        }
    }
}
