package com.alibaba.alink.operator.stream.onlinelearning.kernel;

import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.params.onlinelearning.OnlineLearningTrainParams;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/kernel/OnlineLearningKernel.class */
/**
 * Abstract base for online-learning kernels.
 *
 * <p>This base class only reads optimizer hyper-parameters from {@link Params} and holds a little
 * shared state. Locating columns, computing gradients, updating the model and (de)serializing the
 * model are all delegated to subclasses.
 */
public abstract class OnlineLearningKernel implements Serializable {
    /** Learning rate used for {@code SGD} when {@code LEARNING_RATE} is not set. */
    private static final double DEFAULT_SGD_LEARNING_RATE = 5.0E-4d;

    /** Learning rate used for {@code ADAM}, {@code RMSprop} and {@code MOMENTUM} when not set. */
    private static final double DEFAULT_ADAPTIVE_LEARNING_RATE = 0.001d;

    /** Learning rate used for {@code ADAGRAD} when {@code LEARNING_RATE} is not set. */
    private static final double DEFAULT_ADAGRAD_LEARNING_RATE = 0.01d;

    /** Value of {@code OnlineLearningTrainParams.ALPHA}. */
    protected final double alpha;
    /** Value of {@code OnlineLearningTrainParams.BETA}. */
    protected final double beta;
    /** Value of {@code OnlineLearningTrainParams.L_1} (presumably the L1 regularization weight). */
    protected final double l1;
    /** Value of {@code OnlineLearningTrainParams.L_2} (presumably the L2 regularization weight). */
    protected final double l2;
    /** Small constant for subclasses; presumably a guard against division by zero — confirm. */
    protected static final double EPS = 1.0E-8d;
    /** Optimizer selected via {@code OnlineLearningTrainParams.OPTIM_METHOD}. */
    protected final OnlineLearningTrainParams.OptimMethod optimMethod;
    /**
     * Learning rate: the user-supplied {@code LEARNING_RATE}, otherwise a per-optimizer default
     * (0.0 for optimizers without a default).
     */
    protected double learningRate;
    /** Value of {@code OnlineLearningTrainParams.GAMMA}. */
    protected final double gamma;
    /** Value of {@code OnlineLearningTrainParams.BETA_1}. */
    protected double beta1;
    /** Value of {@code OnlineLearningTrainParams.BETA_2}. */
    protected double beta2;
    /** Starts at 1.0; presumably a running power of {@code beta1} maintained by subclasses. */
    protected double beta1Power = 1.0d;
    /** Starts at 1.0; presumably a running power of {@code beta2} maintained by subclasses. */
    protected double beta2Power = 1.0d;
    /** Shared gradient buffer; presumably keyed by feature index — confirm in subclasses. */
    protected final Map<Integer, double[]> sparseGradient = new HashMap<>();

    /**
     * Reads the optimizer hyper-parameters from {@code params}.
     *
     * @param params training parameters; {@code LEARNING_RATE} may be absent (null), in which case
     *     a default is chosen based on {@code OPTIM_METHOD}
     * @throws NullPointerException if {@code LEARNING_RATE} is absent and {@code OPTIM_METHOD} is
     *     null
     */
    public OnlineLearningKernel(Params params) {
        this.alpha = ((Double) params.get(OnlineLearningTrainParams.ALPHA)).doubleValue();
        this.beta = ((Double) params.get(OnlineLearningTrainParams.BETA)).doubleValue();
        this.l1 = ((Double) params.get(OnlineLearningTrainParams.L_1)).doubleValue();
        this.l2 = ((Double) params.get(OnlineLearningTrainParams.L_2)).doubleValue();
        this.optimMethod =
            (OnlineLearningTrainParams.OptimMethod) params.get(OnlineLearningTrainParams.OPTIM_METHOD);
        // Look the parameter up once; null means "not configured by the user".
        Double configuredLearningRate = (Double) params.get(OnlineLearningTrainParams.LEARNING_RATE);
        this.learningRate = configuredLearningRate != null
            ? configuredLearningRate.doubleValue()
            : defaultLearningRate(this.optimMethod);
        this.gamma = ((Double) params.get(OnlineLearningTrainParams.GAMMA)).doubleValue();
        this.beta1 = ((Double) params.get(OnlineLearningTrainParams.BETA_1)).doubleValue();
        this.beta2 = ((Double) params.get(OnlineLearningTrainParams.BETA_2)).doubleValue();
    }

    /**
     * Returns the default learning rate for an optimizer when the user did not set one.
     *
     * @param method the configured optimizer; must not be null
     * @return the default learning rate, or 0.0 for optimizers that have no default here
     * @throws NullPointerException if {@code method} is null
     */
    private static double defaultLearningRate(OnlineLearningTrainParams.OptimMethod method) {
        Objects.requireNonNull(method, "optimMethod must be set when learningRate is not given");
        switch (method) {
            case SGD:
                return DEFAULT_SGD_LEARNING_RATE;
            case ADAM:
            case RMSprop:
            case MOMENTUM:
                return DEFAULT_ADAPTIVE_LEARNING_RATE;
            case ADAGRAD:
                return DEFAULT_ADAGRAD_LEARNING_RATE;
            default:
                // Other optimizers keep the implicit field default of 0.0 (previous behavior).
                return 0.0d;
        }
    }

    /**
     * Returns the index of the vector column in the input schema.
     *
     * @param tableSchema the input table schema
     * @return the column index (exact meaning is defined by the subclass)
     */
    public abstract int getVectorIdx(TableSchema tableSchema);

    /**
     * Returns the index of the label column in the input schema.
     *
     * @param tableSchema the input table schema
     * @return the column index (exact meaning is defined by the subclass)
     */
    public abstract int getLabelIdx(TableSchema tableSchema);

    /**
     * Returns the indices of the feature columns in the input schema.
     *
     * @param tableSchema the input table schema
     * @return the feature column indices (exact meaning is defined by the subclass)
     */
    public abstract int[] getFeatureIndices(TableSchema tableSchema);

    /**
     * Returns the gradient computed by the most recent {@link #calcGradient} call(s).
     *
     * @return the gradient map; presumably keyed by feature index — confirm in subclasses
     */
    public abstract Map<Integer, double[]> getGradient();

    /**
     * Computes the gradient for one sample.
     *
     * @param vector the sample's feature vector
     * @param obj the sample's label (type depends on the subclass)
     * @throws Exception if the subclass fails to compute the gradient
     */
    public abstract void calcGradient(Vector vector, Object obj) throws Exception;

    /**
     * Applies an update to the model.
     *
     * @param obj update payload; its type is defined by the subclass — NOTE(review): confirm
     */
    public abstract void updateModel(Object obj);

    /**
     * Serializes the current model.
     *
     * @return the model as a list of rows
     */
    public abstract List<Row> serializeModel();

    /**
     * Restores the model from rows.
     *
     * @param list rows, presumably as produced by {@link #serializeModel()}
     */
    public abstract void deserializeModel(List<Row> list);
}
