package com.alibaba.alink.operator.common.tree.parallelcart.criteria;

import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.LabelCounter;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/criteria/AlinkCriteria.class */
/**
 * Split criteria that accumulates weighted gradient and hessian sums and scores a node as
 * {@code gradientSum * gradientSum / hessionSum}.
 *
 * <p>Statistics can be accumulated per sample ({@link #add(double, double, double, int)} /
 * {@link #subtract(double, double, double, int)}) or by merging/differencing whole criteria
 * ({@link #add(Criteria)} / {@link #subtract(Criteria)}). Both paths update all four statistics
 * (gradient sum, hessian sum, weight sum, instance count) so that they stay consistent.
 */
public class AlinkCriteria extends HessionBaseCriteria {
    private static final long serialVersionUID = 3057273388324167712L;

    /** Sum of per-sample gradients, each multiplied by the sample weight. */
    double gradientSum;

    /** Sum of per-sample hessians, each multiplied by the sample weight. */
    double hessionSum;

    /**
     * Creates a criteria initialized with precomputed statistics.
     *
     * @param weightSum    forwarded to {@code super(weightSum, numInstances)}; presumably the
     *                     initial value of the inherited {@code weightSum} field
     * @param numInstances forwarded to {@code super(weightSum, numInstances)}; presumably the
     *                     initial value of the inherited {@code numInstances} field
     * @param gradientSum  initial weighted gradient sum
     * @param hessionSum   initial weighted hessian sum
     */
    public AlinkCriteria(double weightSum, int numInstances, double gradientSum, double hessionSum) {
        super(weightSum, numInstances);
        this.gradientSum = gradientSum;
        this.hessionSum = hessionSum;
    }

    /**
     * Packs this criteria's statistics into a {@link LabelCounter}.
     *
     * @return a counter built from {@code weightSum}, {@code numInstances} and the array
     *         {@code {gradientSum, hessionSum}}
     */
    @Override
    public LabelCounter toLabelCounter() {
        return new LabelCounter(this.weightSum, this.numInstances, new double[]{this.gradientSum, this.hessionSum});
    }

    /**
     * Returns the squared gradient sum divided by the hessian sum.
     *
     * <p>No regularization term is added to the denominator. A zero hessian sum yields
     * NaN or infinity; {@link #gain(Criteria...)} guards against that case before calling this.
     *
     * @return {@code gradientSum^2 / hessionSum}
     */
    @Override
    public double impurity() {
        return (this.gradientSum * this.gradientSum) / this.hessionSum;
    }

    /**
     * Computes the split gain as the absolute difference between the summed impurities of the
     * children and this node's impurity.
     *
     * @param criteriaArr child criteria; each must be an {@code AlinkCriteria}
     * @return the gain, or {@code Criteria.INVALID_GAIN} if this node's or any child's hessian
     *         sum equals {@code Criteria.INVALID_GAIN}
     * @throws ClassCastException if a child is not an {@code AlinkCriteria}
     */
    @Override
    public double gain(Criteria... criteriaArr) {
        // NOTE(review): presumably INVALID_GAIN is 0.0 (this reads like decompiler constant
        // folding), which makes these checks divide-by-zero guards for impurity(). Confirm.
        if (this.hessionSum == Criteria.INVALID_GAIN) {
            return Criteria.INVALID_GAIN;
        }
        double parentImpurity = impurity();
        double childrenImpurity = 0.0d;
        for (Criteria child : criteriaArr) {
            AlinkCriteria alinkChild = (AlinkCriteria) child;
            if (alinkChild.hessionSum == Criteria.INVALID_GAIN) {
                return Criteria.INVALID_GAIN;
            }
            childrenImpurity += alinkChild.impurity();
        }
        return Math.abs(childrenImpurity - parentImpurity);
    }

    /**
     * Merges another criteria's statistics into this one.
     *
     * @param criteria the criteria to add; must be an {@code AlinkCriteria}
     * @return {@code this}, for chaining
     * @throws ClassCastException if {@code criteria} is not an {@code AlinkCriteria}
     */
    @Override
    public AlinkCriteria add(Criteria criteria) {
        AlinkCriteria other = (AlinkCriteria) criteria;
        this.gradientSum += other.gradientSum;
        // The hessian sum must be merged too: impurity() divides by it, and the per-sample
        // add(...) overload already keeps it in step with the gradient sum.
        this.hessionSum += other.hessionSum;
        this.weightSum += other.weightSum;
        this.numInstances += other.numInstances;
        return this;
    }

    /**
     * Removes another criteria's statistics from this one.
     *
     * @param criteria the criteria to subtract; must be an {@code AlinkCriteria}
     * @return {@code this}, for chaining
     * @throws ClassCastException if {@code criteria} is not an {@code AlinkCriteria}
     */
    @Override
    public AlinkCriteria subtract(Criteria criteria) {
        AlinkCriteria other = (AlinkCriteria) criteria;
        this.gradientSum -= other.gradientSum;
        // Keep the hessian sum consistent with the gradient sum (see add(Criteria)).
        this.hessionSum -= other.hessionSum;
        this.weightSum -= other.weightSum;
        this.numInstances -= other.numInstances;
        return this;
    }

    /**
     * Clears all accumulated statistics.
     *
     * <p>The double fields are assigned {@code Criteria.INVALID_GAIN}. NOTE(review): this looks
     * like a decompiled {@code 0.0}; confirm that INVALID_GAIN is 0.0.
     */
    @Override
    public void reset() {
        this.gradientSum = Criteria.INVALID_GAIN;
        this.hessionSum = Criteria.INVALID_GAIN;
        this.weightSum = Criteria.INVALID_GAIN;
        this.numInstances = 0;
    }

    /**
     * Accumulates a single sample's statistics.
     *
     * @param gradient the sample's gradient; multiplied by {@code weight} before accumulation
     * @param hession  the sample's hessian; multiplied by {@code weight} before accumulation
     * @param weight   the sample weight, added to {@code weightSum}
     * @param count    number of instances to add to {@code numInstances}
     */
    @Override
    public void add(double gradient, double hession, double weight, int count) {
        this.gradientSum += gradient * weight;
        this.hessionSum += hession * weight;
        this.weightSum += weight;
        this.numInstances += count;
    }

    /**
     * Removes a single sample's statistics; exact inverse of
     * {@link #add(double, double, double, int)}.
     *
     * @param gradient the sample's gradient; multiplied by {@code weight} before removal
     * @param hession  the sample's hessian; multiplied by {@code weight} before removal
     * @param weight   the sample weight, subtracted from {@code weightSum}
     * @param count    number of instances to subtract from {@code numInstances}
     */
    @Override
    public void subtract(double gradient, double hession, double weight, int count) {
        this.gradientSum -= gradient * weight;
        this.hessionSum -= hession * weight;
        this.weightSum -= weight;
        this.numInstances -= count;
    }

    /**
     * Returns the same value as {@link #gain(Criteria...)}; this criteria has no separate
     * "actual" gain formula.
     *
     * @param criteriaArr child criteria; each must be an {@code AlinkCriteria}
     * @return the result of {@code gain(criteriaArr)}
     */
    @Override
    public double actualGain(Criteria... criteriaArr) {
        return gain(criteriaArr);
    }
}
