package com.alibaba.alink.operator.common.tree.parallelcart.criteria;

import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.LabelCounter;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/criteria/PaiCriteria.class */
/**
 * Split criterion that scores a node from its accumulated gradient statistics.
 *
 * <p>A node's impurity is {@code gradientSum^2 / weightSum} (see {@link #impurity()}). The split
 * gain is either the sum of the children's impurities ({@link #gain}) or that sum minus this
 * node's own impurity ({@link #actualGain}). In both cases a result below {@link #PAI_EPS} is
 * reported as {@link Criteria#INVALID_GAIN}.
 *
 * <p>Instances are mutable accumulators: every {@code add}/{@code subtract} overload updates the
 * sums in place.
 */
public class PaiCriteria extends HessionBaseCriteria {
    private static final long serialVersionUID = -4251358962194586557L;

    /** Running sum of the gradient values accumulated into this criterion. */
    double gradientSum;

    /**
     * Running sum of the hessian values accumulated into this criterion. It is carried through
     * add/subtract/reset/toLabelCounter but is not used by {@link #impurity()}.
     */
    double hessionSum;

    /** Gains strictly below this threshold are reported as {@link Criteria#INVALID_GAIN}. */
    public static final double PAI_EPS = 1.0E-6d;

    /**
     * Creates a criterion from pre-aggregated statistics.
     *
     * @param weightSum    forwarded to the superclass constructor (presumably initialises
     *                     {@code weightSum})
     * @param numInstances forwarded to the superclass constructor (presumably initialises
     *                     {@code numInstances})
     * @param gradientSum  initial gradient sum
     * @param hessionSum   initial hessian sum
     */
    public PaiCriteria(double weightSum, int numInstances, double gradientSum, double hessionSum) {
        super(weightSum, numInstances);
        this.gradientSum = gradientSum;
        this.hessionSum = hessionSum;
    }

    /**
     * Packs the accumulated statistics into a {@link LabelCounter}.
     *
     * @return a counter holding {@code weightSum}, {@code numInstances} and the distribution
     *     {@code [gradientSum, hessionSum]}
     */
    @Override
    public LabelCounter toLabelCounter() {
        return new LabelCounter(
            this.weightSum, this.numInstances, new double[] {this.gradientSum, this.hessionSum});
    }

    /**
     * Returns {@code gradientSum^2 / weightSum}. If {@code weightSum} equals
     * {@link Criteria#INVALID_GAIN}, it returns {@link Criteria#INVALID_GAIN} instead.
     */
    @Override
    public double impurity() {
        // Guards the division below; this assumes INVALID_GAIN is 0.0 -- TODO confirm.
        // NOTE(review): the denominator is weightSum, not hessionSum; confirm this is the
        // intended formulation.
        if (this.weightSum == Criteria.INVALID_GAIN) {
            return Criteria.INVALID_GAIN;
        }
        return this.gradientSum * this.gradientSum / this.weightSum;
    }

    /**
     * Scores a candidate split by the summed impurity of its children.
     *
     * @param children the candidate child criteria
     * @return the summed child impurity, or {@link Criteria#INVALID_GAIN} if it is below
     *     {@link #PAI_EPS}
     */
    @Override
    public double gain(Criteria... children) {
        double sum = sumImpurity(children);
        return sum < PAI_EPS ? Criteria.INVALID_GAIN : sum;
    }

    /**
     * Adds another criterion's statistics into this one, in place.
     *
     * @param criteria must be a {@code PaiCriteria}; otherwise the cast throws
     *     {@link ClassCastException}
     * @return this instance
     */
    @Override
    public PaiCriteria add(Criteria criteria) {
        PaiCriteria other = (PaiCriteria) criteria;
        this.gradientSum += other.gradientSum;
        this.hessionSum += other.hessionSum;
        this.weightSum += other.weightSum;
        this.numInstances += other.numInstances;
        return this;
    }

    /**
     * Subtracts another criterion's statistics from this one, in place.
     *
     * @param criteria must be a {@code PaiCriteria}; otherwise the cast throws
     *     {@link ClassCastException}
     * @return this instance
     */
    @Override
    public PaiCriteria subtract(Criteria criteria) {
        PaiCriteria other = (PaiCriteria) criteria;
        this.gradientSum -= other.gradientSum;
        this.hessionSum -= other.hessionSum;
        this.weightSum -= other.weightSum;
        this.numInstances -= other.numInstances;
        return this;
    }

    /**
     * Clears the accumulated statistics. The three sums are set to {@link Criteria#INVALID_GAIN}
     * (presumably 0.0) and the instance count is set to 0.
     */
    @Override
    public void reset() {
        this.gradientSum = Criteria.INVALID_GAIN;
        this.hessionSum = Criteria.INVALID_GAIN;
        this.weightSum = Criteria.INVALID_GAIN;
        this.numInstances = 0;
    }

    /**
     * Accumulates one contribution into the running sums.
     *
     * @param gradient gradient amount added to {@code gradientSum}
     * @param hessian  hessian amount added to {@code hessionSum}
     * @param weight   weight added to {@code weightSum}
     * @param count    number added to {@code numInstances}
     */
    @Override
    public void add(double gradient, double hessian, double weight, int count) {
        this.gradientSum += gradient;
        this.hessionSum += hessian;
        this.weightSum += weight;
        this.numInstances += count;
    }

    /**
     * Removes one contribution from the running sums.
     *
     * @param gradient gradient amount subtracted from {@code gradientSum}
     * @param hessian  hessian amount subtracted from {@code hessionSum}
     * @param weight   weight subtracted from {@code weightSum}
     * @param count    number subtracted from {@code numInstances}
     */
    @Override
    public void subtract(double gradient, double hessian, double weight, int count) {
        this.gradientSum -= gradient;
        this.hessionSum -= hessian;
        this.weightSum -= weight;
        this.numInstances -= count;
    }

    /**
     * Scores a candidate split by how much the children's summed impurity exceeds this node's.
     *
     * @param children the candidate child criteria
     * @return {@code sum(child.impurity()) - this.impurity()}, or {@link Criteria#INVALID_GAIN}
     *     if that difference is below {@link #PAI_EPS}
     */
    @Override
    public double actualGain(Criteria... children) {
        // Children are summed first, then this node's impurity is taken, as in the original.
        double improvement = sumImpurity(children) - impurity();
        return improvement < PAI_EPS ? Criteria.INVALID_GAIN : improvement;
    }

    /** Sums {@link Criteria#impurity()} over the given criteria, in array order. */
    private static double sumImpurity(Criteria[] children) {
        double sum = 0.0d;
        for (Criteria child : children) {
            sum += child.impurity();
        }
        return sum;
    }
}
