package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.tree.TreeUtil;
import java.io.Serializable;
import java.util.Arrays;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/Criteria.class */
public abstract class Criteria implements Cloneable, Serializable {
    public static final double INVALID_GAIN = 0.0d;
    public static final double EPS = 1.0E-15d;
    private static final long serialVersionUID = -4855396890380900139L;
    protected double weightSum;
    protected int numInstances;

    /*
     * Decompiler-emitted lookup table for switching over TreeUtil.TreeType.
     * Renamed from com.alibaba.alink.operator.common.tree.Criteria$1 (not a valid source-level
     * class name); loaded from com/alibaba/alink/operator/common/tree/Criteria$1.class.
     *
     * The int array is sized to the number of TreeType constants and indexed by ordinal().
     * Each slot assigned below holds a case label in 1..6; slots that are never assigned keep
     * the array default of 0, which none of those labels match.
     */
    static /* synthetic */ class AnonymousClass1 {
        // ordinal -> case label (1 = MSE, 2 = AVG, 3 = PARTITION, 4 = GINI, 5 = INFOGAIN, 6 = INFOGAINRATIO)
        static final /* synthetic */ int[] $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType = new int[TreeUtil.TreeType.values().length];

        static {
            // Each constant is registered in its own try block so that a NoSuchFieldError for one
            // constant leaves only that slot at 0 and does not prevent the remaining assignments.
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.MSE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Intentionally ignored: the slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.AVG.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Intentionally ignored: the slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.PARTITION.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Intentionally ignored: the slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.GINI.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // Intentionally ignored: the slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.INFOGAIN.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
                // Intentionally ignored: the slot stays 0.
            }
            try {
                $SwitchMap$com$alibaba$alink$operator$common$tree$TreeUtil$TreeType[TreeUtil.TreeType.INFOGAINRATIO.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
                // Intentionally ignored: the slot stays 0.
            }
        }
    }

    /**
     * Base class for criteria that keep a per-index weight distribution in addition to the
     * total weight and instance count held by {@link Criteria}.
     *
     * <p>All mutators update {@code distributions}, {@code weightSum} and {@code numInstances}
     * together, in that order.
     */
    public static abstract class ClassificationCriteria extends Criteria {
        private static final long serialVersionUID = -631328604371947566L;

        /** Accumulated weight per index. The constructor stores the given array without copying. */
        double[] distributions;

        /**
         * @param d    initial total weight
         * @param i    initial instance count
         * @param dArr initial distribution array (stored by reference)
         */
        ClassificationCriteria(double d, int i, double[] dArr) {
            super(d, i);
            this.distributions = dArr;
        }

        /**
         * Adds weight {@code d} to slot {@code i} and to the total, and {@code i2} to the
         * instance count.
         */
        public void add(int i, double d, int i2) {
            distributions[i] += d;
            weightSum += d;
            numInstances += i2;
        }

        /**
         * Removes weight {@code d} from slot {@code i} and from the total, and {@code i2} from
         * the instance count.
         */
        public void subtract(int i, double d, int i2) {
            distributions[i] -= d;
            weightSum -= d;
            numInstances -= i2;
        }

        /**
         * Accumulates another classification criteria into this one, slot by slot.
         *
         * @param criteria must be a {@code ClassificationCriteria}; it is cast unconditionally
         * @return this instance
         */
        @Override
        public ClassificationCriteria add(Criteria criteria) {
            final ClassificationCriteria other = (ClassificationCriteria) criteria;
            for (int slot = 0; slot < distributions.length; slot++) {
                distributions[slot] += other.distributions[slot];
            }
            weightSum += other.weightSum;
            numInstances += other.numInstances;
            return this;
        }

        /**
         * Removes another classification criteria from this one, slot by slot.
         *
         * @param criteria must be a {@code ClassificationCriteria}; it is cast unconditionally
         * @return this instance
         */
        @Override
        public Criteria subtract(Criteria criteria) {
            final ClassificationCriteria other = (ClassificationCriteria) criteria;
            for (int slot = 0; slot < distributions.length; slot++) {
                distributions[slot] -= other.distributions[slot];
            }
            weightSum -= other.weightSum;
            numInstances -= other.numInstances;
            return this;
        }

        /**
         * Returns a copy whose distribution array is independent of this instance's array.
         * (Method name is the decompiler's rename of {@code clone}.)
         */
        @Override
        public ClassificationCriteria mo606clone() {
            final ClassificationCriteria copy = (ClassificationCriteria) super.mo606clone();
            // The parent copy shares the array reference, so give the copy its own array.
            copy.distributions = distributions.clone();
            copy.weightSum = weightSum;
            copy.numInstances = numInstances;
            return copy;
        }

        /** Builds a {@link LabelCounter} from the total weight, instance count and distribution array. */
        @Override
        public LabelCounter toLabelCounter() {
            return new LabelCounter(weightSum, numInstances, distributions);
        }

        /** Zeroes every distribution slot, the total weight and the instance count. */
        @Override
        public void reset() {
            Arrays.fill(distributions, 0.0d);
            weightSum = 0.0d;
            numInstances = 0;
        }
    }

    /**
     * Classification criteria whose impurity is the base-2 entropy of the normalised
     * distribution.
     */
    public static abstract class Entropy extends ClassificationCriteria {
        private static final double LOG2 = Math.log(2.0d);
        private static final long serialVersionUID = 7602253844112062279L;

        Entropy(double d, int i, double[] dArr) {
            super(d, i, dArr);
        }

        /**
         * Base-2 logarithm with the convention {@code log2(0) == 0}, so that a zero
         * probability contributes nothing to an entropy sum.
         */
        static double log2(double d) {
            if (d == 0.0d) {
                return 0.0d;
            }
            return Math.log(d) / LOG2;
        }

        /**
         * Returns {@code -sum(p * log2(p))} over all slots, where {@code p} is the slot weight
         * divided by the total weight; returns 0 when the total weight is below {@code EPS}.
         */
        @Override
        public double impurity() {
            if (weightSum < EPS) {
                return 0.0d;
            }
            double acc = 0.0d;
            for (double slotWeight : distributions) {
                final double p = slotWeight / weightSum;
                acc += p * log2(p);
            }
            return -1.0d * acc;
        }
    }

    /**
     * Names of the gain measures. Three constants share their names (ignoring case) with the
     * nested criteria classes Gini, InfoGain and InfoGainRatio, and the fourth matches MSE
     * exactly; the code that maps a constant to a criteria class is not part of this file.
     * (Loaded from com/alibaba/alink/operator/common/tree/Criteria$Gain.class.)
     */
    public enum Gain {
        GINI,
        INFOGAIN,
        INFOGAINRATIO,
        MSE;

        /** Parameter descriptor named "gain" whose value type is {@link Gain}. */
        public static final ParamInfo<Gain> GAIN = ParamInfoFactory.createParamInfo("gain", Gain.class).build();
    }

    /** Classification criteria using the Gini index as impurity. */
    public static class Gini extends ClassificationCriteria {
        private static final long serialVersionUID = 8996209222867178997L;

        public Gini(double d, int i, double[] dArr) {
            super(d, i, dArr);
        }

        /**
         * Returns {@code 1 - sum(p^2)} over all slots, where {@code p} is the slot weight
         * divided by the total weight; returns 0 when the total weight is below {@code EPS}.
         */
        @Override
        public double impurity() {
            if (weightSum < EPS) {
                return 0.0d;
            }
            double squares = 0.0d;
            for (int slot = 0; slot < distributions.length; slot++) {
                final double p = distributions[slot] / weightSum;
                squares += p * p;
            }
            return 1.0d - squares;
        }

        /**
         * Returns this node's impurity minus each child's impurity scaled by the child's share
         * of this node's weight; returns {@code INVALID_GAIN} when this node's weight is below
         * {@code EPS}.
         */
        @Override
        public double gain(Criteria... criteriaArr) {
            if (weightSum < EPS) {
                return INVALID_GAIN;
            }
            double remaining = impurity();
            for (int k = 0; k < criteriaArr.length; k++) {
                final Criteria child = criteriaArr[k];
                remaining -= (child.weightSum / weightSum) * child.impurity();
            }
            return remaining;
        }
    }

    /** Entropy-based criteria whose gain is the reduction in entropy (information gain). */
    public static class InfoGain extends Entropy {
        private static final long serialVersionUID = 5185893562589077621L;

        public InfoGain(double d, int i, double[] dArr) {
            super(d, i, dArr);
        }

        /**
         * Returns this node's entropy minus each child's entropy scaled by the child's share of
         * this node's weight; returns {@code INVALID_GAIN} when this node's weight is below
         * {@code EPS}.
         */
        @Override
        public double gain(Criteria... criteriaArr) {
            if (weightSum < EPS) {
                return INVALID_GAIN;
            }
            double remaining = impurity();
            for (int k = 0; k < criteriaArr.length; k++) {
                final Criteria child = criteriaArr[k];
                remaining -= (child.weightSum / weightSum) * child.impurity();
            }
            return remaining;
        }
    }

    /** Entropy-based criteria whose gain is information gain divided by the split information. */
    public static class InfoGainRatio extends Entropy {
        private static final long serialVersionUID = 2010844081408314373L;

        public InfoGainRatio(double d, int i, double[] dArr) {
            super(d, i, dArr);
        }

        /**
         * Computes the information gain of the split and divides it by the split information
         * {@code -sum(f * log2(f))}, where {@code f} is each child's share of this node's
         * weight.
         *
         * @return {@code INVALID_GAIN} when this node's weight or the split information is
         *         below {@code EPS}; otherwise the ratio
         */
        @Override
        public double gain(Criteria... criteriaArr) {
            if (weightSum < EPS) {
                return INVALID_GAIN;
            }
            double infoGain = impurity();
            double splitInfo = 0.0d;
            for (int k = 0; k < criteriaArr.length; k++) {
                final Criteria child = criteriaArr[k];
                final double fraction = child.weightSum / weightSum;
                infoGain -= fraction * child.impurity();
                splitInfo -= fraction * log2(fraction);
            }
            if (splitInfo < EPS) {
                return INVALID_GAIN;
            }
            return infoGain / splitInfo;
        }
    }

    /**
     * Regression criteria whose impurity is the weighted variance of the accumulated values,
     * computed as {@code squareSum / weightSum - (sum / weightSum)^2}.
     *
     * <p>For that formula to be a variance, {@code sum} must hold {@code sum(weight * value)}
     * and {@code squareSum} must hold {@code sum(weight * value^2)}.
     */
    public static class MSE extends RegressionCriteria {
        private static final long serialVersionUID = 8895470577519000835L;

        /** Running {@code sum(weight * value)}. */
        double sum;

        /** Running {@code sum(weight * value^2)}. */
        double squareSum;

        /**
         * @param d  initial total weight
         * @param i  initial instance count
         * @param d2 initial value of {@code sum}
         * @param d3 initial value of {@code squareSum}
         */
        public MSE(double d, int i, double d2, double d3) {
            super(d, i);
            this.sum = d2;
            this.squareSum = d3;
        }

        public double getSum() {
            return this.sum;
        }

        public void setSum(double d) {
            this.sum = d;
        }

        public double getSquareSum() {
            return this.squareSum;
        }

        public void setSquareSum(double d) {
            this.squareSum = d;
        }

        /**
         * Adds one weighted observation.
         *
         * @param d  observed value
         * @param d2 weight of the observation
         * @param i  number of instances the observation represents
         */
        @Override
        public void add(double d, double d2, int i) {
            double weighted = d * d2;
            this.sum += weighted;
            // weight * value^2, not (weight * value)^2: squaring the weighted value would make
            // an observation of weight w differ from w unit-weight copies of the same value.
            this.squareSum += weighted * d;
            this.weightSum += d2;
            this.numInstances += i;
        }

        /**
         * Removes one weighted observation; the exact inverse of {@link #add(double, double, int)}.
         *
         * @param d  observed value
         * @param d2 weight of the observation
         * @param i  number of instances the observation represents
         */
        @Override
        public void subtract(double d, double d2, int i) {
            double weighted = d * d2;
            this.sum -= weighted;
            // Must mirror add(): remove weight * value^2.
            this.squareSum -= weighted * d;
            this.weightSum -= d2;
            this.numInstances -= i;
        }

        /**
         * Accumulates another MSE criteria into this one.
         *
         * @param criteria must be an {@code MSE}; it is cast unconditionally
         * @return this instance
         */
        @Override
        public MSE add(Criteria criteria) {
            MSE mse = (MSE) criteria;
            this.sum += mse.sum;
            this.squareSum += mse.squareSum;
            this.weightSum += mse.weightSum;
            this.numInstances += mse.numInstances;
            return this;
        }

        /**
         * Removes another MSE criteria from this one.
         *
         * @param criteria must be an {@code MSE}; it is cast unconditionally
         * @return this instance
         */
        @Override
        public Criteria subtract(Criteria criteria) {
            MSE mse = (MSE) criteria;
            this.sum -= mse.sum;
            this.squareSum -= mse.squareSum;
            this.weightSum -= mse.weightSum;
            this.numInstances -= mse.numInstances;
            return this;
        }

        /** Zeroes the sums, the total weight and the instance count. */
        @Override
        public void reset() {
            this.sum = 0.0d;
            this.squareSum = 0.0d;
            this.weightSum = 0.0d;
            this.numInstances = 0;
        }

        /**
         * Builds a {@link LabelCounter} from the total weight, the instance count and the
         * two-element array {@code {sum, squareSum}}.
         */
        @Override
        public LabelCounter toLabelCounter() {
            return new LabelCounter(this.weightSum, this.numInstances, new double[]{this.sum, this.squareSum});
        }

        /**
         * Returns the weighted variance {@code squareSum / weightSum - mean^2}, or 0 when the
         * total weight is below {@code EPS}.
         */
        @Override
        public double impurity() {
            if (getWeightSum() < EPS) {
                return 0.0d;
            }
            double mean = this.sum / this.weightSum;
            return (this.squareSum / this.weightSum) - (mean * mean);
        }

        /**
         * Returns this node's variance minus each child's variance scaled by the child's share
         * of this node's weight; returns {@code INVALID_GAIN} when this node's weight is below
         * {@code EPS}.
         */
        @Override
        public double gain(Criteria... criteriaArr) {
            if (this.weightSum < EPS) {
                return INVALID_GAIN;
            }
            double remaining = impurity();
            for (Criteria child : criteriaArr) {
                remaining -= (child.weightSum / this.weightSum) * child.impurity();
            }
            return remaining;
        }
    }

    /**
     * Base class for regression criteria, which accumulate weighted numeric observations
     * rather than a per-class distribution.
     */
    public static abstract class RegressionCriteria extends Criteria {
        private static final long serialVersionUID = -2896832363044118789L;

        /**
         * @param d initial total weight
         * @param i initial instance count
         */
        public RegressionCriteria(double d, int i) {
            super(d, i);
        }

        /**
         * Adds one weighted observation. In the MSE implementation in this file, {@code d} is
         * the observed value, {@code d2} its weight and {@code i} the number of instances it
         * represents.
         */
        public abstract void add(double d, double d2, int i);

        /**
         * Removes one weighted observation; parameters mirror {@link #add(double, double, int)}.
         */
        public abstract void subtract(double d, double d2, int i);

        /**
         * Returns a copy of this criteria by delegating to {@link Criteria#mo606clone()}.
         *
         * <p>The decompiler emitted this as a synthetic bridge declared to return {@code Object}
         * and throw {@code CloneNotSupportedException}. That signature is not a legal override
         * of {@code Criteria.mo606clone()}, so it is declared here with the parent's signature.
         */
        @Override
        public Criteria mo606clone() {
            return super.mo606clone();
        }
    }

    /**
     * Creates a criteria seeded with an initial total weight and instance count.
     *
     * @param d initial value of {@code weightSum}
     * @param i initial value of {@code numInstances}
     */
    public Criteria(double d, int i) {
        numInstances = i;
        weightSum = d;
    }

    /**
     * Exports the accumulated statistics as a {@link LabelCounter}. The implementations in this
     * file pass the total weight, the instance count and a double array (the class distribution,
     * or {@code {sum, squareSum}} for MSE).
     */
    public abstract LabelCounter toLabelCounter();

    /**
     * Returns the impurity of the accumulated statistics. Every implementation in this file
     * returns 0 when the total weight is below 1e-15.
     */
    public abstract double impurity();

    /**
     * Returns the gain of splitting this criteria into the given children. The implementations
     * in this file start from {@link #impurity()} of this instance and subtract each child's
     * impurity scaled by {@code child.weightSum / this.weightSum}; they return
     * {@link #INVALID_GAIN} when this instance's total weight is below 1e-15.
     *
     * @param criteriaArr the criteria of the child partitions
     */
    public abstract double gain(Criteria... criteriaArr);

    /**
     * Accumulates {@code criteria} into this instance. The implementations in this file cast the
     * argument to their own type, mutate this instance in place and return it.
     */
    public abstract Criteria add(Criteria criteria);

    /**
     * Removes {@code criteria} from this instance. The implementations in this file cast the
     * argument to their own type, mutate this instance in place and return it.
     */
    public abstract Criteria subtract(Criteria criteria);

    /** Returns the total weight accumulated so far. */
    public double getWeightSum() {
        return weightSum;
    }

    /** Returns the number of instances accumulated so far. */
    public int getNumInstances() {
        return numInstances;
    }

    /**
     * Clears the accumulated statistics. The implementations in this file set the total weight,
     * the instance count and their own accumulators to zero.
     */
    public abstract void reset();

    /**
     * Returns a field-by-field copy of this criteria made with {@link Object#clone()}.
     *
     * <p>The decompiler renamed this method from {@code clone}. Under the new name it overrides
     * nothing in {@code Object}, so it carries no {@code @Override} annotation.
     *
     * @return the copy; array fields of subclasses are shared unless the subclass copies them
     * @throws AkIllegalStateException if {@code Object.clone()} reports that cloning is unsupported
     */
    public Criteria mo606clone() {
        final Object copy;
        try {
            copy = super.clone();
        } catch (CloneNotSupportedException e) {
            throw new AkIllegalStateException("Can not clone the criteria.", e);
        }
        return (Criteria) copy;
    }

    /**
     * Tells whether a tree type denotes a regression tree.
     *
     * <p>Switches on the enum directly rather than through a synthetic ordinal table, so no
     * case label depends on an unrelated integer constant.
     *
     * @param treeType the tree type to classify; must not be null
     * @return {@code true} for {@code MSE}; {@code false} for {@code AVG}, {@code PARTITION},
     *         {@code GINI}, {@code INFOGAIN} and {@code INFOGAINRATIO}
     * @throws AkIllegalArgumentException for any other tree type
     * @throws NullPointerException if {@code treeType} is null
     */
    public static boolean isRegression(TreeUtil.TreeType treeType) {
        switch (treeType) {
            case MSE:
                return true;
            case AVG:
            case PARTITION:
            case GINI:
            case INFOGAIN:
            case INFOGAINRATIO:
                return false;
            default:
                throw new AkIllegalArgumentException("Not support " + treeType + " yet.");
        }
    }
}
