package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/ann/FeedForwardModel.class */
/**
 * A feed-forward network topology: an ordered stack of {@link Layer}s, each paired (at the same
 * index) with the {@link LayerModel} produced by {@link Layer#createModel()}.
 *
 * <p>All layer weights share one flat {@link DenseVector}. Layer {@code i} is handed the offset equal
 * to the sum of {@link Layer#getWeightSize()} over layers {@code 0..i-1}. The last layer model must
 * be an {@link AnnLossFunction} for {@link #computeGradient} to work.
 *
 * <p>{@link #computeGradient} reuses delta buffers stored in an instance field, so concurrent calls
 * on the same instance would share those buffers.
 */
public class FeedForwardModel extends TopologyModel {
    private static final long serialVersionUID = 6320266940893689929L;

    /** Layer definitions, in evaluation order. */
    private final List<Layer> layers;

    /** One model per entry of {@link #layers}, at the same index. */
    private final List<LayerModel> layerModels;

    /**
     * Delta buffers (one per layer except the last) reused across {@link #computeGradient} calls.
     * They are reallocated only when absent or when the batch size changes. Not serialized.
     */
    private transient List<DenseMatrix> deltas = null;

    /**
     * Creates a model for the given layers, instantiating one {@link LayerModel} per layer.
     *
     * @param list the layers in evaluation order; stored by reference, not copied
     */
    public FeedForwardModel(List<Layer> list) {
        this.layers = list;
        this.layerModels = new ArrayList<>(list.size());
        for (Layer layer : list) {
            this.layerModels.add(layer.createModel());
        }
    }

    /**
     * Hands each layer model the shared weight vector together with the offset at which that
     * layer's weights start.
     *
     * @param denseVector flat vector holding the weights of all layers back to back
     */
    @Override
    public void resetModel(DenseVector denseVector) {
        int offset = 0;
        for (int i = 0; i < layers.size(); i++) {
            layerModels.get(i).resetModel(denseVector, offset);
            offset += layers.get(i).getWeightSize();
        }
    }

    /**
     * Runs a forward pass over a batch.
     *
     * <p>Fresh output buffers are allocated on every call. An in-place layer
     * ({@link Layer#isInPlace()}) shares the buffer of the layer before it, so its entry in the
     * returned list is the same matrix object as the previous entry.
     *
     * @param data input batch; its row count is used as the batch size of every output buffer
     * @param includeLastLayer whether to evaluate the last layer; when {@code false} the last layer
     *     is not evaluated, but its buffer is still present in the returned list
     * @return one output matrix per layer, in layer order
     * @throws IllegalStateException if the first layer is an in-place layer
     */
    @Override
    public List<DenseMatrix> forward(DenseMatrix data, boolean includeLastLayer) {
        List<DenseMatrix> outputs = allocateOutputs(data.numRows(), data.numCols());
        layerModels.get(0).eval(data, outputs.get(0));
        int end = includeLastLayer ? layers.size() : layers.size() - 1;
        for (int i = 1; i < end; i++) {
            layerModels.get(i).eval(outputs.get(i - 1), outputs.get(i));
        }
        return outputs;
    }

    /**
     * Allocates one output buffer per layer for a batch of {@code batchSize} rows. In-place layers
     * reuse the previous layer's buffer and leave the running column count unchanged.
     */
    private List<DenseMatrix> allocateOutputs(int batchSize, int inputSize) {
        List<DenseMatrix> outputs = new ArrayList<>(layers.size());
        int currentSize = inputSize;
        for (int i = 0; i < layers.size(); i++) {
            Layer layer = layers.get(i);
            if (layer.isInPlace()) {
                if (i == 0) {
                    // There is no previous buffer to share; fail clearly instead of get(-1).
                    throw new IllegalStateException("The first layer cannot be an in-place layer");
                }
                outputs.add(outputs.get(i - 1));
            } else {
                int outputSize = layer.getOutputSize(currentSize);
                outputs.add(new DenseMatrix(batchSize, outputSize));
                currentSize = outputSize;
            }
        }
        return outputs;
    }

    /**
     * Predicts the output for a single example.
     *
     * @param denseVector the input features; its data array is copied before use
     * @return a copy of the last layer's output for the resulting one-row batch
     */
    @Override
    public DenseVector predict(DenseVector denseVector) {
        DenseMatrix input = new DenseMatrix(1, denseVector.size(), denseVector.getData().clone());
        List<DenseMatrix> outputs = forward(input, true);
        DenseMatrix lastOutput = outputs.get(outputs.size() - 1);
        return new DenseVector(lastOutput.getData().clone());
    }

    /**
     * Computes the loss on a batch and, optionally, the weight gradients of every layer.
     *
     * <p>The method proceeds in four steps:
     * <ol>
     *   <li>Runs a full forward pass.</li>
     *   <li>Asks the final {@link AnnLossFunction} for the loss, which also fills the delta buffer of
     *       the layer before it.</li>
     *   <li>Back-propagates deltas through the remaining layers.</li>
     *   <li>Calls {@link LayerModel#grad} on every layer with that layer's weight offset.</li>
     * </ol>
     *
     * @param data input batch
     * @param target expected outputs for the batch, passed to the loss function
     * @param cumGrad gradient vector laid out like the weight vector and handed to every layer's
     *     {@code grad} call. Whether it is overwritten or accumulated into is decided by the layer
     *     models. If {@code null}, only the loss is computed.
     * @return the loss reported by the last layer
     * @throws AkUnsupportedOperationException if the last layer is not an {@link AnnLossFunction}
     * @throws IllegalStateException if there is no layer before the loss layer
     */
    @Override
    public double computeGradient(DenseMatrix data, DenseMatrix target, DenseVector cumGrad) {
        int lastIndex = layerModels.size() - 1;
        // Validate before doing any work: the forward pass and buffers are useless otherwise.
        if (!(layerModels.get(lastIndex) instanceof AnnLossFunction)) {
            throw new AkUnsupportedOperationException("The last layer should be loss function");
        }
        if (lastIndex < 1) {
            // The loss writes into deltas.get(lastIndex - 1), which must exist.
            throw new IllegalStateException(
                "A feed-forward model needs at least one layer before the loss layer");
        }

        List<DenseMatrix> outputs = forward(data, true);
        ensureDeltaBuffers(data.numRows(), data.numCols());

        AnnLossFunction lossLayer = (AnnLossFunction) layerModels.get(lastIndex);
        double loss = lossLayer.loss(outputs.get(lastIndex), target, deltas.get(lastIndex - 1));
        if (cumGrad == null) {
            return loss;
        }

        // Back-propagate: layer i turns its own delta into the delta of layer i - 1.
        for (int i = lastIndex - 1; i >= 1; i--) {
            layerModels.get(i).computePrevDelta(deltas.get(i), outputs.get(i), deltas.get(i - 1));
        }

        int offset = 0;
        for (int i = 0; i <= lastIndex; i++) {
            DenseMatrix input = (i == 0) ? data : outputs.get(i - 1);
            // The loss layer has no delta buffer of its own.
            DenseMatrix delta = (i == lastIndex) ? null : deltas.get(i);
            layerModels.get(i).grad(delta, input, cumGrad, offset);
            offset += layers.get(i).getWeightSize();
        }
        return loss;
    }

    /**
     * (Re)allocates {@link #deltas} when it is absent or sized for a different batch size. It holds
     * one buffer per layer except the last, each sized by {@link Layer#getOutputSize(int)} for the
     * running input width. Callers must ensure there are at least two layers.
     */
    private void ensureDeltaBuffers(int batchSize, int inputSize) {
        if (deltas != null && deltas.get(0).numRows() == batchSize) {
            return;
        }
        List<DenseMatrix> buffers = new ArrayList<>(layers.size() - 1);
        int currentSize = inputSize;
        for (int i = 0; i < layers.size() - 1; i++) {
            int outputSize = layers.get(i).getOutputSize(currentSize);
            buffers.add(new DenseMatrix(batchSize, outputSize));
            currentSize = outputSize;
        }
        deltas = buffers;
    }
}
