package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.operator.common.tree.Criteria;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/ann/AffineLayerModel.class */
/**
 * Layer model holding a weight matrix {@code w} ({@code numIn x numOut}) and a bias vector
 * {@code b} (length {@code numOut}), together with equally shaped gradient buffers.
 *
 * <p>Parameters and gradients are exchanged with callers through a flat {@link DenseVector}:
 * starting at a caller-supplied offset, the matrix entries are stored row by row and are
 * followed immediately by the bias entries.
 */
public class AffineLayerModel extends LayerModel {
    /** Weight matrix, {@code numIn} rows by {@code numOut} columns. */
    private final DenseMatrix w;
    /** Bias vector of length {@code numOut}. */
    private final DenseVector b;
    /** Scratch buffer with the same shape as {@link #w}, used while accumulating gradients. */
    private final DenseMatrix gradw;
    /** Scratch buffer with the same length as {@link #b}, used while accumulating gradients. */
    private final DenseVector gradb;
    /** Lazily created all-ones vector; rebuilt whenever the required length changes. */
    private transient DenseVector ones;

    /**
     * Allocates the weight/bias storage and the matching gradient buffers.
     *
     * @param affineLayer layer description providing {@code numIn} and {@code numOut}
     */
    public AffineLayerModel(AffineLayer affineLayer) {
        final int inputs = affineLayer.numIn;
        final int outputs = affineLayer.numOut;
        this.w = new DenseMatrix(inputs, outputs);
        this.b = new DenseVector(outputs);
        this.gradw = new DenseMatrix(inputs, outputs);
        this.gradb = new DenseVector(outputs);
    }

    /**
     * Copies values out of a flat vector into a matrix/vector pair.
     *
     * <p>The number of values read is determined by the dimensions of this model's own
     * {@code w} and {@code b}, not by the destination arguments.
     *
     * @param flat    flat source vector
     * @param offset  index in {@code flat} of the first matrix entry
     * @param weights destination matrix, filled row by row
     * @param bias    destination vector, filled after the matrix
     */
    private void unpack(DenseVector flat, int offset, DenseMatrix weights, DenseVector bias) {
        final int rows = this.w.numRows();
        final int cols = this.w.numCols();
        int cursor = offset;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                weights.set(r, c, flat.get(cursor++));
            }
        }
        // The bias entries sit directly behind the last matrix entry.
        final int biasLength = this.b.size();
        for (int k = 0; k < biasLength; k++) {
            bias.set(k, flat.get(cursor++));
        }
    }

    /**
     * Copies a matrix/vector pair into a flat vector; the exact inverse of
     * {@link #unpack(DenseVector, int, DenseMatrix, DenseVector)}.
     *
     * @param flat    flat destination vector
     * @param offset  index in {@code flat} that receives the first matrix entry
     * @param weights source matrix, read row by row
     * @param bias    source vector, written after the matrix
     */
    private void pack(DenseVector flat, int offset, DenseMatrix weights, DenseVector bias) {
        final int rows = this.w.numRows();
        final int cols = this.w.numCols();
        int cursor = offset;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                flat.set(cursor++, weights.get(r, c));
            }
        }
        final int biasLength = this.b.size();
        for (int k = 0; k < biasLength; k++) {
            flat.set(cursor++, bias.get(k));
        }
    }

    /**
     * Returns the cached all-ones vector, replacing it first if it is missing or has a
     * different length than requested.
     *
     * @param length required vector length
     * @return the cached vector of ones with {@code length} entries
     */
    private DenseVector onesOfLength(int length) {
        final boolean reusable = this.ones != null && this.ones.size() == length;
        if (!reusable) {
            this.ones = DenseVector.ones(length);
        }
        return this.ones;
    }

    /**
     * Loads this layer's weights and bias from a flat vector.
     *
     * @param denseVector flat vector containing the parameters
     * @param i           index in {@code denseVector} where this layer's parameters begin
     */
    @Override // com.alibaba.alink.operator.common.classification.ann.LayerModel
    public void resetModel(DenseVector denseVector, int i) {
        unpack(denseVector, i, this.w, this.b);
    }

    /**
     * Writes the bias into every row of {@code denseMatrix2} and then calls
     * {@code BLAS.gemm} with {@code denseMatrix} and {@code w} (neither transposed) and both
     * scalar coefficients equal to 1.
     *
     * <p>Assuming the usual gemm contract {@code C = alpha * A * B + beta * C}, the result is
     * {@code denseMatrix2 = denseMatrix * w + b} applied per row — TODO confirm against BLAS.
     *
     * @param denseMatrix  layer input; its row count decides how many output rows are filled
     * @param denseMatrix2 output matrix, overwritten by this call
     */
    @Override // com.alibaba.alink.operator.common.classification.ann.LayerModel
    public void eval(DenseMatrix denseMatrix, DenseMatrix denseMatrix2) {
        final int rowCount = denseMatrix.numRows();
        final int biasLength = this.b.size();
        for (int row = 0; row < rowCount; row++) {
            for (int col = 0; col < biasLength; col++) {
                denseMatrix2.set(row, col, this.b.get(col));
            }
        }
        BLAS.gemm(1.0d, denseMatrix, false, this.w, false, 1.0d, denseMatrix2);
    }

    /**
     * Calls {@code BLAS.gemm} with {@code denseMatrix} and the transpose of {@code w},
     * storing into {@code denseMatrix3}.
     *
     * @param denseMatrix  left operand of the product (presumably this layer's delta)
     * @param denseMatrix2 not used by this implementation
     * @param denseMatrix3 matrix receiving the result
     */
    @Override // com.alibaba.alink.operator.common.classification.ann.LayerModel
    public void computePrevDelta(DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseMatrix denseMatrix3) {
        // NOTE(review): a tree-module constant is used as the gemm "beta" coefficient here;
        // presumably it equals 0.0 so that denseMatrix3 is overwritten rather than
        // accumulated into — confirm the constant's value.
        final double beta = Criteria.INVALID_GAIN;
        BLAS.gemm(1.0d, denseMatrix, false, this.w, true, beta, denseMatrix3);
    }

    /**
     * Adds this layer's gradient contribution to the slice of {@code denseVector} that
     * starts at {@code i}.
     *
     * <p>The current slice is unpacked into the gradient buffers, the buffers are updated
     * through {@code BLAS.gemm} / {@code BLAS.gemv} (both with coefficients of 1), and the
     * buffers are packed back into the same slice.
     *
     * @param denseMatrix  right operand of the gemm call and matrix operand of the gemv call
     *                     (presumably this layer's delta)
     * @param denseMatrix2 transposed left operand of the gemm call (presumably the layer
     *                     input); its row count sizes the all-ones vector
     * @param denseVector  flat gradient vector, read and then updated in place
     * @param i            index in {@code denseVector} where this layer's gradient begins
     */
    @Override // com.alibaba.alink.operator.common.classification.ann.LayerModel
    public void grad(DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseVector denseVector, int i) {
        // Start from the values already stored in the caller's vector.
        unpack(denseVector, i, this.gradw, this.gradb);
        final int rowCount = denseMatrix2.numRows();

        // Weight part: gemm with denseMatrix2 transposed, added onto gradw.
        BLAS.gemm(1.0d, denseMatrix2, true, denseMatrix, false, 1.0d, this.gradw);

        // Bias part: gemv of transposed denseMatrix with a vector of ones, added onto gradb.
        final DenseVector allOnes = onesOfLength(rowCount);
        BLAS.gemv(1.0d, denseMatrix, true, allOnes, 1.0d, this.gradb);

        pack(denseVector, i, this.gradw, this.gradb);
    }
}
