package com.alibaba.alink.operator.stream.onlinelearning.kernel;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.params.onlinelearning.OnlineLearningTrainParams;
import java.util.List;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/kernel/SoftmaxOnlineLearningKernel.class */
public class SoftmaxOnlineLearningKernel extends LinearOnlineLearningKernel {
    /**
     * Number of coefficient slices in the loaded model, taken from {@code modelData.coefVectors.length}
     * in {@link #deserializeModel(List)}. The flat {@code modelData.coefVector} is read as {@code k1}
     * consecutive slices of equal length ({@code coefVector.size() / k1} each).
     */
    private int k1;

    public SoftmaxOnlineLearningKernel(Params params) {
        super(params, LinearModelType.LR);
    }

    /**
     * Accumulates the softmax gradient of a single sample into {@code sparseGradient}.
     *
     * <p>For every slice {@code k < k1} and feature {@code j}, this adds
     * {@code x_j * (phi_k - [label == k])} to entry {@code k * size + j}. Here {@code phi} comes from
     * {@link #calcPhi(Vector, DenseVector)}. Each touched entry is a two-element array:
     * {@code [0]} receives the summed gradient and {@code [1]} is incremented by one per contribution.
     *
     * @param vector the sample's features; an intercept term {@code 1.0} is prefixed when the model has one
     * @param obj    the sample's label, matched with {@code equals} against {@code modelData.labelValues}
     */
    @Override
    public void calcGradient(Vector vector, Object obj) throws Exception {
        int labelIdx = findLabelIndex(obj);
        if (this.modelData.hasInterceptItem) {
            vector = vector.prefix(1.0d);
        }
        double[] phi = calcPhi(vector, this.modelData.coefVector);
        // Labels at positions >= k1 have no coefficient slice of their own (their margin is the
        // implicit 0 in calcPhi), so no indicator term is subtracted for them.
        if (labelIdx < this.k1) {
            phi[labelIdx] -= 1.0d;
        }
        int size = this.modelData.coefVector.size() / this.k1;
        if (vector instanceof SparseVector) {
            int[] indices = ((SparseVector) vector).getIndices();
            double[] values = ((SparseVector) vector).getValues();
            for (int k = 0; k < this.k1; k++) {
                int offset = k * size;
                for (int j = 0; j < indices.length; j++) {
                    accumulateGradient(offset + indices[j], values[j] * phi[k]);
                }
            }
        } else {
            double[] data = ((DenseVector) vector).getData();
            for (int k = 0; k < this.k1; k++) {
                int offset = k * size;
                for (int j = 0; j < size; j++) {
                    accumulateGradient(offset + j, data[j] * phi[k]);
                }
            }
        }
    }

    /**
     * Returns the first position whose entry in {@code modelData.labelValues} equals {@code label}.
     *
     * <p>NOTE(review): an unknown label falls back to position 0, i.e. it is trained as if it were the
     * first label. This preserves the original behavior; confirm whether callers should reject it instead.
     */
    private int findLabelIndex(Object label) {
        for (int i = 0; i < this.modelData.labelValues.length; i++) {
            if (label.equals(this.modelData.labelValues[i])) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Adds one gradient contribution at flat coefficient position {@code key}.
     *
     * <p>It creates the entry {@code {grad, 1.0}} on first touch. Otherwise it adds {@code grad} to
     * slot 0 and {@code 1.0} to slot 1. A single {@code get} replaces the former
     * {@code containsKey} + {@code get} + {@code get} sequence.
     */
    private void accumulateGradient(int key, double grad) {
        double[] entry = this.sparseGradient.get(key);
        if (entry == null) {
            this.sparseGradient.put(key, new double[] {grad, 1.0d});
        } else {
            entry[0] += grad;
            entry[1] += 1.0d;
        }
    }

    /**
     * Computes {@code phi_k = exp(m_k) / (1 + sum_j exp(m_j))} for {@code k < k1}.
     *
     * <p>Here {@code m_k} is the dot product of {@code vector} with the k-th slice of {@code denseVector}.
     * The {@code 1} in the denominator is the implicit reference term with margin 0.
     *
     * <p>All margins, including that implicit 0, are shifted by their maximum before
     * exponentiating. The shift cancels in the ratio, but it stops {@code exp} from overflowing to
     * {@code Infinity} (and the result from becoming {@code NaN}) when a margin is large.
     *
     * @param vector      sample features (dense or sparse)
     * @param denseVector flat coefficients, {@code k1} consecutive slices
     * @return a new array of length {@code k1}
     */
    private double[] calcPhi(Vector vector, DenseVector denseVector) {
        double[] phi = new double[this.k1];
        int size = denseVector.size() / this.k1;
        double[] coef = denseVector.getData();
        if (vector instanceof DenseVector) {
            double[] x = ((DenseVector) vector).getData();
            for (int k = 0; k < this.k1; k++) {
                int offset = k * size;
                double margin = 0.0d;
                for (int j = 0; j < size; j++) {
                    margin += x[j] * coef[offset + j];
                }
                phi[k] = margin;
            }
        } else {
            int[] indices = ((SparseVector) vector).getIndices();
            double[] values = ((SparseVector) vector).getValues();
            for (int k = 0; k < this.k1; k++) {
                int offset = k * size;
                double margin = 0.0d;
                for (int j = 0; j < indices.length; j++) {
                    margin += values[j] * coef[offset + indices[j]];
                }
                phi[k] = margin;
            }
        }
        // Start at 0.0 so the reference term's margin takes part in the max.
        double maxMargin = 0.0d;
        for (double margin : phi) {
            maxMargin = Math.max(maxMargin, margin);
        }
        double denominator = Math.exp(-maxMargin);
        for (int k = 0; k < this.k1; k++) {
            phi[k] = Math.exp(phi[k] - maxMargin);
            denominator += phi[k];
        }
        for (int k = 0; k < this.k1; k++) {
            phi[k] /= denominator;
        }
        return phi;
    }

    /**
     * Loads model data from {@code list} and allocates per-coefficient optimizer state.
     *
     * <p>{@code nParam} is allocated for every optimizer except SGD. {@code zParam} is allocated only for
     * ADAM and FTRL. Both arrays have length {@code coefVector.size()}. It also records {@link #k1}.
     */
    @Override
    public void deserializeModel(List<Row> list) {
        this.modelData = new LinearModelDataConverter().load(list);
        if (!this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.SGD)) {
            this.nParam = new double[this.modelData.coefVector.size()];
        }
        if (this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.ADAM)
                || this.optimMethod.equals(OnlineLearningTrainParams.OptimMethod.FTRL)) {
            this.zParam = new double[this.modelData.coefVector.size()];
        }
        this.k1 = this.modelData.coefVectors.length;
    }
}
