package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/ann/AnnObjFunc.class */
/**
 * Objective function used by the optimizers for the networks in the {@code ann} package.
 *
 * <p>Loss and gradient evaluation are delegated to a {@link TopologyModel} built from the
 * configured {@link Topology}. Each incoming sample tuple is first converted by a {@link Stacker}
 * into a pair of matrices, which are passed to {@link TopologyModel#computeGradient}. Second-order
 * information (Hessian) is not supported.
 *
 * <p>NOTE(review): the {@link TopologyModel} is cached in a plain (non-synchronized) field and
 * reused across calls. An instance therefore presumably must not be shared between threads —
 * confirm against the caller.
 */
public class AnnObjFunc extends OptimObjFunc {
    private static final long serialVersionUID = 7635533586488766373L;

    /** Network structure; used to build {@link #topologyModel} on first use. */
    private final Topology topology;

    /** Converts an input sample tuple into the matrix pair handed to the model. */
    private final Stacker stacker;

    /**
     * Lazily created model, reloaded with new coefficients via {@code resetModel} on later calls.
     * It is transient, so it is null after deserialization and is rebuilt on first use.
     */
    private transient TopologyModel topologyModel;

    /**
     * Creates the objective function.
     *
     * @param topology network structure used to build the model
     * @param inputSize first int argument forwarded unchanged to {@link Stacker}; presumably the
     *     input (feature) dimension — confirm against {@code Stacker}
     * @param outputSize second int argument forwarded unchanged to {@link Stacker}; presumably the
     *     output (label) dimension — confirm against {@code Stacker}
     * @param onehotLabel boolean forwarded unchanged to {@link Stacker}; presumably whether labels
     *     are one-hot encoded — confirm against {@code Stacker}
     * @param params parameters forwarded to {@link OptimObjFunc}
     */
    public AnnObjFunc(
            Topology topology, int inputSize, int outputSize, boolean onehotLabel, Params params) {
        super(params);
        this.topology = topology;
        this.stacker = new Stacker(inputSize, outputSize, onehotLabel);
    }

    /**
     * Computes the loss for one stacked sample.
     *
     * <p>The loss is the value returned by {@link TopologyModel#computeGradient} when it is called
     * with a {@code null} gradient buffer.
     *
     * @param sample sample tuple to be unstacked by the {@link Stacker}
     * @param coefVector current model coefficients, loaded into the cached model before evaluation
     * @return the value reported by {@code computeGradient}
     */
    @Override
    public double calcLoss(Tuple3<Double, Double, Vector> sample, DenseVector coefVector) {
        Tuple2<DenseMatrix, DenseMatrix> stacked = stacker.unstack(sample);
        return modelFor(coefVector).computeGradient(stacked.f0, stacked.f1, null);
    }

    /**
     * Computes the gradient for one stacked sample into {@code gradient}.
     *
     * @param sample sample tuple to be unstacked by the {@link Stacker}
     * @param coefVector current model coefficients, loaded into the cached model before evaluation
     * @param gradient buffer passed to {@link TopologyModel#computeGradient}; whether it is
     *     overwritten or accumulated into is decided by {@code computeGradient} (not visible here)
     */
    @Override
    public void updateGradient(
            Tuple3<Double, Double, Vector> sample, DenseVector coefVector, DenseVector gradient) {
        Tuple2<DenseMatrix, DenseMatrix> stacked = stacker.unstack(sample);
        modelFor(coefVector).computeGradient(stacked.f0, stacked.f1, gradient);
    }

    /**
     * Not supported: this objective provides no second-order information.
     *
     * @throws AkUnsupportedOperationException always
     */
    @Override
    public void updateHessian(
            Tuple3<Double, Double, Vector> sample, DenseVector coefVector, DenseMatrix hessian) {
        throw new AkUnsupportedOperationException("not supported.");
    }

    /** Returns {@code false}: {@link #updateHessian} is unsupported. */
    @Override
    public boolean hasSecondDerivative() {
        return false;
    }

    /**
     * Returns the cached model loaded with {@code coefVector}.
     *
     * <p>On first use, and after deserialization, the model is created from {@link #topology}.
     * On later calls the existing model is reset with the new coefficients instead of being
     * rebuilt.
     */
    private TopologyModel modelFor(DenseVector coefVector) {
        if (topologyModel == null) {
            topologyModel = topology.getModel(coefVector);
        } else {
            topologyModel.resetModel(coefVector);
        }
        return topologyModel;
    }
}
