package com.alibaba.alink.operator.common.tree.parallelcart.booster;

import com.alibaba.alink.operator.common.linear.unarylossfunc.UnaryLossFunc;
import com.alibaba.alink.operator.common.tree.parallelcart.BoostingObjs;
import com.alibaba.alink.operator.common.tree.parallelcart.data.Slice;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/booster/GradientBaseBooster.class */
/**
 * First-order booster: for every row in its slice it evaluates the loss derivative and keeps
 * both the derivative and its square. It computes no second-order terms, so
 * {@link #getHessions()} returns {@code null}.
 *
 * <p>Buffers are indexed by absolute row index, i.e. the same index {@code i} that
 * {@link #boosting} reads from its input arrays.
 */
public class GradientBaseBooster implements Booster {
    /** Loss whose {@code derivative} is evaluated per row. */
    final UnaryLossFunc loss;
    /** Half-open row range [start, end) processed by {@link #boosting}. */
    final Slice slice;
    /** Per-row loss derivative, written at indices [slice.start, slice.end). */
    final double[] gradient;
    /** Per-row squared loss derivative, written at indices [slice.start, slice.end). */
    final double[] gradientSqr;
    /** Weights supplied at construction; returned as-is by {@link #getWeights()}. */
    final double[] weights;

    /**
     * Creates a booster over the rows covered by {@code slice}.
     *
     * @param unaryLossFunc loss used to compute per-row derivatives
     * @param weights       weight array returned unchanged by {@link #getWeights()}
     * @param slice         half-open row range [start, end) processed by {@link #boosting}
     */
    public GradientBaseBooster(UnaryLossFunc unaryLossFunc, double[] weights, Slice slice) {
        this.loss = unaryLossFunc;
        this.slice = slice;
        this.weights = weights;
        // boosting() stores at absolute indices in [slice.start, slice.end), so the buffers must
        // hold slice.end entries. Sizing them as (end - start) overflowed whenever start > 0.
        this.gradient = new double[slice.end];
        this.gradientSqr = new double[slice.end];
    }

    /**
     * Recomputes the derivative and squared derivative for every row in the slice.
     *
     * @param boostingObjs unused by this booster
     * @param label        values passed as the second argument of {@code loss.derivative}
     *                     (presumably the labels; confirm against the Booster interface)
     * @param pred         values passed as the first argument of {@code loss.derivative}
     *                     (presumably the current predictions; confirm against the interface)
     */
    @Override // com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster
    public void boosting(BoostingObjs boostingObjs, double[] label, double[] pred) {
        for (int i = slice.start; i < slice.end; i++) {
            double g = loss.derivative(pred[i], label[i]);
            gradient[i] = g;
            gradientSqr[i] = g * g;
        }
    }

    /** Returns the weight array supplied at construction (not a copy). */
    @Override // com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster
    public double[] getWeights() {
        return weights;
    }

    /** Returns the internal derivative buffer (not a copy), filled by {@link #boosting}. */
    @Override // com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster
    public double[] getGradients() {
        return gradient;
    }

    /** Always {@code null}: this booster computes first-order information only. */
    @Override // com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster
    public double[] getHessions() {
        return null;
    }

    /** Returns the internal squared-derivative buffer (not a copy), filled by {@link #boosting}. */
    @Override // com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster
    public double[] getGradientsSqr() {
        return gradientSqr;
    }
}
