package com.alibaba.alink.operator.common.regression.glm;

import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.LinearSolver;
import com.alibaba.alink.common.utils.AlinkSerializable;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.lda.LdaVariable;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.regression.GlmModelData;
import com.alibaba.alink.operator.common.regression.GlmModelDataConverter;
import com.alibaba.alink.operator.common.regression.glm.famliy.Binomial;
import com.alibaba.alink.operator.common.regression.glm.famliy.FamilyFunction;
import com.alibaba.alink.operator.common.regression.glm.famliy.Gaussian;
import com.alibaba.alink.operator.common.regression.glm.famliy.Poisson;
import com.alibaba.alink.operator.common.regression.glm.link.Identity;
import com.alibaba.alink.operator.common.regression.glm.link.LinkFunction;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.F2jBLAS;
import com.github.fommil.netlib.LAPACK;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.netlib.util.intW;

/* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil.class */
public class GlmUtil {
    // Scalar companion to ckLanczos below. Presumably the "g" parameter of a Lanczos
    // approximation used by the lnGamma helper defined elsewhere in this class — TODO confirm.
    private static final double gLanczos = 4.7421875d;
    // Public, mutable tolerance constant; its consumers are not visible in this part of the file.
    public static double EPSILON = 1.0E-16d;
    // Public, mutable step/threshold constant; its consumers are not visible in this part of the file.
    public static double DELTA = 0.1d;
    // Fifteen fixed series coefficients. Presumably the Lanczos coefficients paired with
    // gLanczos above — NOTE(review): confirm against the lnGamma implementation.
    private static final double[] ckLanczos = {0.9999999999999971d, 57.15623566586292d, -59.59796035547549d, 14.136097974741746d, -0.4919138160976202d, 3.399464998481189E-5d, 4.652362892704858E-5d, -9.837447530487956E-5d, 1.580887032249125E-4d, -2.1026444172410488E-4d, 2.1743961811521265E-4d, -1.643181065367639E-4d, 8.441822398385275E-5d, -2.6190838401581408E-5d, 3.6899182659531625E-6d};

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$AggSummary.class */
    /**
     * Assembles a {@link GlmModelSummary} from three broadcast variables:
     * {@code "deviance"} (a Tuple5 of sums), {@code "aic"} and the fitted model
     * registered under {@code OptimVariable.model}. The mapped input value itself is ignored.
     */
    private static class AggSummary extends RichMapFunction<Double, GlmModelSummary> {
        private static final long serialVersionUID = 1993165824137115835L;
        private boolean fitIntercept;
        private int numFeature;
        private String familyName;
        private long count;
        private Double aic;
        private double nullDeviance;
        private double deviance;
        private double dispersion;
        private WeightedLeastSquaresModel model;

        /**
         * @param fitIntercept whether the model carries an intercept term
         * @param numFeature   number of feature coefficients
         * @param familyName   name of the distribution family, compared against
         *                     {@code Binomial.name()} and {@code Poisson.name()}
         */
        public AggSummary(boolean fitIntercept, int numFeature, String familyName) {
            this.fitIntercept = fitIntercept;
            this.numFeature = numFeature;
            this.familyName = familyName;
        }

        /**
         * Reads the broadcast statistics. From the "deviance" tuple: f0 is the null
         * deviance, f1 the deviance, f2 the un-normalised dispersion and f4 the
         * (rounded) instance count.
         */
        public void open(Configuration configuration) {
            Tuple5 stats = (Tuple5) getRuntimeContext().getBroadcastVariable("deviance").get(0);
            double rawCount = ((Double) stats.f4).doubleValue();
            this.count = Math.round(rawCount);
            this.nullDeviance = ((Double) stats.f0).doubleValue();
            this.deviance = ((Double) stats.f1).doubleValue();
            this.dispersion = ((Double) stats.f2).doubleValue();
            Object broadcastAic = getRuntimeContext().getBroadcastVariable("aic").get(0);
            this.aic = (Double) broadcastAic;
            Object broadcastModel = getRuntimeContext().getBroadcastVariable(OptimVariable.model).get(0);
            this.model = (WeightedLeastSquaresModel) broadcastModel;
        }

        /** Fills every field of a new summary from the accessor methods below. */
        public GlmModelSummary map(Double d) throws Exception {
            GlmModelSummary summary = new GlmModelSummary();
            summary.rank = rank();
            summary.degreeOfFreedom = degreeOfFreedom();
            summary.residualDegreeOfFreeDom = residualDegreeOfFreeDom();
            summary.residualDegreeOfFreedomNull = residualDegreeOfFreedomNull();
            summary.aic = aic();
            summary.dispersion = dispersion();
            summary.deviance = deviance();
            summary.nullDeviance = nullDeviance();
            summary.coefficients = this.model.coefficients;
            summary.intercept = this.model.intercept;
            summary.coefficientStandardErrors = coefficientStandardErrors();
            summary.tValues = tValues();
            summary.pValues = pValues();
            return summary;
        }

        /** @return number of fitted parameters: features, plus one when an intercept is fitted */
        public int rank() {
            if (this.fitIntercept) {
                return this.numFeature + 1;
            }
            return this.numFeature;
        }

        /** @return instance count minus {@link #rank()} */
        public long degreeOfFreedom() {
            return this.count - rank();
        }

        /** @return same value as {@link #degreeOfFreedom()} */
        public long residualDegreeOfFreeDom() {
            return degreeOfFreedom();
        }

        /** @return instance count, reduced by one when an intercept is fitted */
        public long residualDegreeOfFreedomNull() {
            if (this.fitIntercept) {
                return this.count - 1;
            }
            return this.count;
        }

        /** @return broadcast AIC plus twice the rank, or {@code Double.MIN_VALUE} when no AIC was broadcast */
        public double aic() {
            if (this.aic != null) {
                return this.aic.doubleValue() + (2 * rank());
            }
            return Double.MIN_VALUE;
        }

        /** @return 1.0 for the Binomial and Poisson families; otherwise the broadcast dispersion sum divided by the degrees of freedom */
        public double dispersion() {
            if (hasUnitDispersion()) {
                return 1.0d;
            }
            return this.dispersion / degreeOfFreedom();
        }

        /** @return the broadcast deviance */
        public double deviance() {
            return this.deviance;
        }

        /** @return the broadcast null deviance, or {@code Double.MIN_VALUE} when it is NaN */
        public double nullDeviance() {
            return Double.isNaN(this.nullDeviance) ? Double.MIN_VALUE : this.nullDeviance;
        }

        /** @return sqrt(diagInvAtWA[i] * dispersion) for every entry of the model's diagonal */
        public double[] coefficientStandardErrors() {
            int size = this.model.diagInvAtWA.length;
            double[] stdErrs = new double[size];
            double scale = dispersion();
            int pos = 0;
            while (pos < stdErrs.length) {
                stdErrs[pos] = Math.sqrt(this.model.diagInvAtWA[pos] * scale);
                pos++;
            }
            return stdErrs;
        }

        /**
         * @return coefficient / standard error per feature; when an intercept is fitted,
         *         the slot at index {@code numFeature} holds intercept / its standard error
         */
        public double[] tValues() {
            double[] stats = new double[this.model.diagInvAtWA.length];
            double[] stdErrs = coefficientStandardErrors();
            int pos = 0;
            while (pos < this.numFeature) {
                stats[pos] = this.model.coefficients[pos] / stdErrs[pos];
                pos++;
            }
            if (!this.fitIntercept) {
                return stats;
            }
            stats[this.numFeature] = this.model.intercept / stdErrs[this.numFeature];
            return stats;
        }

        /**
         * @return two-sided p-values of {@link #tValues()}: against a standard normal for
         *         the Binomial and Poisson families, otherwise against a Student t
         *         distribution with {@link #degreeOfFreedom()} degrees of freedom
         */
        public double[] pValues() {
            double[] stats = tValues();
            double[] probs = new double[stats.length];
            if (!hasUnitDispersion()) {
                TDistribution studentT = new TDistribution(degreeOfFreedom());
                for (int k = 0; k < stats.length; k++) {
                    probs[k] = twoSidedTail(studentT.cumulativeProbability(Math.abs(stats[k])));
                }
                return probs;
            }
            NormalDistribution standardNormal = new NormalDistribution();
            for (int k = 0; k < stats.length; k++) {
                probs[k] = twoSidedTail(standardNormal.cumulativeProbability(Math.abs(stats[k])));
            }
            return probs;
        }

        /** True when the family name equals that of Binomial or Poisson. */
        private boolean hasUnitDispersion() {
            if (this.familyName.equals(new Binomial().name())) {
                return true;
            }
            return this.familyName.equals(new Poisson().name());
        }

        /** Converts a CDF value at |t| into the two-sided tail probability. */
        private static double twoSidedTail(double cdfAtAbs) {
            return 2.0d * (1.0d - cdfAtAbs);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$BinomialAicTransform.class */
    /**
     * Maps a row to a binomial log-probability term. Reads the fields at offsets
     * {@code numFeature} (label), {@code numFeature + 1} (weight) and
     * {@code numFeature + 3} (used as the success probability).
     */
    public static class BinomialAicTransform implements MapFunction<Row, Double> {
        private static final long serialVersionUID = 2795419799172034855L;
        private int numFeature;

        BinomialAicTransform(int numFeature) {
            this.numFeature = numFeature;
        }

        /**
         * @return {@code Criteria.INVALID_GAIN} when the rounded weight is zero, otherwise
         *         the log-probability of round(label * weight) successes in round(weight) trials
         */
        public Double map(Row row) throws Exception {
            double label = ((Double) row.getField(this.numFeature)).doubleValue();
            double weight = ((Double) row.getField(this.numFeature + 1)).doubleValue();
            double prob = ((Double) row.getField(this.numFeature + 3)).doubleValue();
            int trials = (int) Math.round(weight);
            if (trials == 0) {
                return Double.valueOf(Criteria.INVALID_GAIN);
            }
            int successes = (int) Math.round(label * weight);
            return Double.valueOf(logProbability(trials, prob, successes));
        }

        /** Log of an indicator: {@code Criteria.INVALID_GAIN} when true, negative infinity when false. */
        private double logI(Boolean bool) {
            return bool.booleanValue() ? Criteria.INVALID_GAIN : Double.NEGATIVE_INFINITY;
        }

        /**
         * Log-probability of {@code i2} successes in {@code i} trials with success
         * probability {@code d}. The degenerate probabilities
         * ({@code Criteria.INVALID_GAIN}, presumably 0.0 — confirm; and 1.0) are handled
         * through {@link #logI(Boolean)} so that no logarithm of zero is taken.
         */
        private double logProbability(int i, double d, int i2) {
            if (d == Criteria.INVALID_GAIN) {
                boolean noSuccess = i2 == 0;
                return logI(Boolean.valueOf(noSuccess));
            }
            if (d == 1.0d) {
                boolean allSuccess = i2 == i;
                return logI(Boolean.valueOf(allSuccess));
            }
            // log of the binomial coefficient via log-gamma
            double logChoose = (GlmUtil.lnGamma(i + 1) - GlmUtil.lnGamma(i2 + 1)) - GlmUtil.lnGamma((i - i2) + 1);
            double successTerm = i2 * Math.log(d);
            double failureTerm = (i - i2) * Math.log(1.0d - d);
            return logChoose + successTerm + failureTerm;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GammaAicTransform.class */
    /**
     * Maps a row to its weighted gamma log-density term. The dispersion is taken from
     * the {@code "deviance"} broadcast tuple as f1 / f3 (deviance over the value that
     * {@code GaussionAicTransform2} stores as the weight sum).
     */
    public static class GammaAicTransform extends RichMapFunction<Row, Double> {
        private static final long serialVersionUID = 4869720011021776554L;
        private int numFeature;
        private double disp;

        public GammaAicTransform(int numFeature) {
            this.numFeature = numFeature;
        }

        /** Computes the dispersion estimate from the broadcast "deviance" tuple. */
        public void open(Configuration configuration) {
            Tuple5 stats = (Tuple5) getRuntimeContext().getBroadcastVariable("deviance").get(0);
            this.disp = ((Double) stats.f1).doubleValue() / ((Double) stats.f3).doubleValue();
        }

        /**
         * @param row row whose fields at offsets {@code numFeature}, {@code numFeature + 1}
         *            and {@code numFeature + 3} hold the label, the weight and the mean
         *            (the sibling Binomial/Poisson transforms use the same slot as the mean)
         * @return weight * log(density of the label under the gamma distribution)
         */
        public Double map(Row row) throws Exception {
            double label = ((Double) row.getField(this.numFeature)).doubleValue();
            double weight = ((Double) row.getField(this.numFeature + 1)).doubleValue();
            double mu = ((Double) row.getField(this.numFeature + 3)).doubleValue();
            // Commons Math's GammaDistribution takes (shape, scale), not (shape, rate).
            // With shape = 1/disp the scale must be mu * disp so that the distribution
            // mean (shape * scale) equals mu; the previous 1/(mu * disp) was a rate.
            GammaDistribution gamma = new GammaDistribution(1.0d / this.disp, mu * this.disp);
            return Double.valueOf(weight * Math.log(gamma.density(label)));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GaussianDataProc.class */
    /**
     * Rewrites a row for the Gaussian case: features are copied unchanged, the label slot
     * receives label minus offset, the weight is kept, and the offset slot is reset to
     * {@code Criteria.INVALID_GAIN}. Slots beyond {@code numFeature + 2} are left unset.
     */
    public static class GaussianDataProc implements MapFunction<Row, Row> {
        private static final long serialVersionUID = 541391719390519905L;
        private int numFeature;

        public GaussianDataProc(int numFeature) {
            this.numFeature = numFeature;
        }

        public Row map(Row row) {
            final int labelIdx = this.numFeature;
            final int weightIdx = this.numFeature + 1;
            final int offsetIdx = this.numFeature + 2;

            Row out = new Row(row.getArity());
            int col = 0;
            while (col < this.numFeature) {
                out.setField(col, row.getField(col));
                col++;
            }
            double label = ((Double) row.getField(labelIdx)).doubleValue();
            double weight = ((Double) row.getField(weightIdx)).doubleValue();
            double offset = ((Double) row.getField(offsetIdx)).doubleValue();
            out.setField(labelIdx, Double.valueOf(label - offset));
            out.setField(weightIdx, Double.valueOf(weight));
            out.setField(offsetIdx, Double.valueOf(Criteria.INVALID_GAIN));
            return out;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GaussianIntercept.class */
    /** Emits f0 / f1 for every (sum, weight) pair in the partition. */
    private static class GaussianIntercept implements MapPartitionFunction<Tuple2<Double, Double>, Double> {
        private static final long serialVersionUID = 7012869802526966852L;

        private GaussianIntercept() {
        }

        public void mapPartition(Iterable<Tuple2<Double, Double>> iterable, Collector<Double> collector) throws Exception {
            Iterator<Tuple2<Double, Double>> pairs = iterable.iterator();
            while (pairs.hasNext()) {
                Tuple2<Double, Double> pair = pairs.next();
                double numerator = ((Double) pair.f0).doubleValue();
                double denominator = ((Double) pair.f1).doubleValue();
                collector.collect(Double.valueOf(numerator / denominator));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GaussianInterceptMap.class */
    /**
     * Maps a row to the pair (weight * (label - offset), weight), reading the label,
     * weight and offset from offsets {@code numFeature}, {@code +1} and {@code +2}.
     */
    private static class GaussianInterceptMap implements MapFunction<Row, Tuple2<Double, Double>> {
        private static final long serialVersionUID = 4683689101022028728L;
        private int numFeature;

        GaussianInterceptMap(int numFeature) {
            this.numFeature = numFeature;
        }

        public Tuple2<Double, Double> map(Row row) {
            double label = ((Double) row.getField(this.numFeature)).doubleValue();
            double weight = ((Double) row.getField(this.numFeature + 1)).doubleValue();
            double offset = ((Double) row.getField(this.numFeature + 2)).doubleValue();
            Double weightedResidual = Double.valueOf(weight * (label - offset));
            return new Tuple2<>(weightedResidual, Double.valueOf(weight));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GaussionAicTransform1.class */
    /** Maps a row to the natural logarithm of the field at offset {@code numFeature + 1} (the weight slot). */
    public static class GaussionAicTransform1 extends RichMapFunction<Row, Double> {
        private static final long serialVersionUID = -5244548914921192186L;
        private int numFeature;

        public GaussionAicTransform1(int numFeature) {
            this.numFeature = numFeature;
        }

        public Double map(Row row) {
            Double weight = (Double) row.getField(this.numFeature + 1);
            double logWeight = Math.log(weight.doubleValue());
            return Double.valueOf(logWeight);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GaussionAicTransform2.class */
    /**
     * Combines an incoming sum with the broadcast {@code "deviance"} tuple into
     * {@code count * (log(deviance / count * 2 * pi) + 1) + 2 - input}.
     */
    public static class GaussionAicTransform2 extends RichMapFunction<Double, Double> {
        private static final long serialVersionUID = 1000961790683281069L;
        private double deviance;
        private double weightSum;
        private double count;

        /** Reads deviance (f1), weight sum (f3) and count (f4) from the broadcast tuple. */
        public void open(Configuration configuration) {
            Tuple5 stats = (Tuple5) getRuntimeContext().getBroadcastVariable("deviance").get(0);
            this.deviance = ((Double) stats.f1).doubleValue();
            this.weightSum = ((Double) stats.f3).doubleValue();
            this.count = ((Double) stats.f4).doubleValue();
        }

        public Double map(Double d) {
            double scaledDeviance = ((this.deviance / this.count) * 2.0d) * Math.PI;
            double perInstance = Math.log(scaledDeviance) + 1.0d;
            double total = (this.count * perInstance) + 2.0d;
            return Double.valueOf(total - d.doubleValue());
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GlmModelSummary.class */
    /**
     * Plain serializable holder for the statistics of a fitted GLM. Within this file
     * every field is populated by {@code AggSummary.map}.
     */
    public static class GlmModelSummary implements Serializable, AlinkSerializable {
        private static final long serialVersionUID = -809281767159062744L;
        /** Number of fitted parameters: feature count, plus one when an intercept is fitted. */
        public int rank;
        /** Instance count minus {@link #rank}. */
        public long degreeOfFreedom;
        /** Set by AggSummary to the same value as {@link #degreeOfFreedom}. */
        public long residualDegreeOfFreeDom;
        /** Instance count, reduced by one when an intercept is fitted. */
        public long residualDegreeOfFreedomNull;
        /** AIC value; AggSummary stores {@code Double.MIN_VALUE} when no AIC was available. */
        public double aic;
        /** Dispersion; AggSummary stores 1.0 for the Binomial and Poisson families. */
        public double dispersion;
        /** Deviance of the fitted model. */
        public double deviance;
        /** Null deviance; AggSummary stores {@code Double.MIN_VALUE} when the value is NaN. */
        public double nullDeviance;
        /** Feature coefficients copied from the fitted model. */
        public double[] coefficients;
        /** Intercept copied from the fitted model. */
        public double intercept;
        /** Standard errors, one per entry of the model's diagInvAtWA array. */
        public double[] coefficientStandardErrors;
        /** Coefficient (and, if fitted, intercept) divided by its standard error. */
        public double[] tValues;
        /** Two-sided p-values corresponding to {@link #tValues}. */
        public double[] pValues;
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GlmModelToWlsModel.class */
    /**
     * Loads the serialized GLM model rows of a partition through
     * {@code GlmModelDataConverter} and emits a single {@code WeightedLeastSquaresModel}
     * carrying the coefficients, diagInvAtWA, intercept flag and intercept, with
     * {@code numInstances} set to zero.
     */
    public static class GlmModelToWlsModel implements MapPartitionFunction<Row, WeightedLeastSquaresModel> {
        private static final long serialVersionUID = -5190469991236996289L;

        public void mapPartition(Iterable<Row> iterable, Collector<WeightedLeastSquaresModel> collector) {
            ArrayList<Row> modelRows = new ArrayList<>();
            for (Row modelRow : iterable) {
                modelRows.add(modelRow);
            }
            GlmModelData glmModel = new GlmModelDataConverter().load(modelRows);

            WeightedLeastSquaresModel wlsModel = new WeightedLeastSquaresModel();
            wlsModel.coefficients = glmModel.coefficients;
            wlsModel.diagInvAtWA = glmModel.diagInvAtWA;
            wlsModel.fitIntercept = glmModel.fitIntercept;
            wlsModel.intercept = glmModel.intercept;
            wlsModel.numInstances = 0L;
            collector.collect(wlsModel);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GlobalSum.class */
    /** Reduce function that adds two doubles. */
    public static class GlobalSum implements ReduceFunction<Double> {
        private static final long serialVersionUID = 5958862729047735260L;

        private GlobalSum() {
        }

        public Double reduce(Double d, Double d2) throws Exception {
            double left = d.doubleValue();
            double right = d2.doubleValue();
            return Double.valueOf(left + right);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GlobalSum2.class */
    /** Reduce function that adds two pairs of doubles component-wise. */
    private static class GlobalSum2 implements ReduceFunction<Tuple2<Double, Double>> {
        private static final long serialVersionUID = 9038160815514679752L;

        private GlobalSum2() {
        }

        public Tuple2<Double, Double> reduce(Tuple2<Double, Double> tuple2, Tuple2<Double, Double> tuple22) throws Exception {
            Double first = Double.valueOf(((Double) tuple2.f0).doubleValue() + ((Double) tuple22.f0).doubleValue());
            Double second = Double.valueOf(((Double) tuple2.f1).doubleValue() + ((Double) tuple22.f1).doubleValue());
            return new Tuple2<>(first, second);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GlobalSum5.class */
    /** Reduce function that adds two 5-tuples of doubles component-wise. */
    private static class GlobalSum5 implements ReduceFunction<Tuple5<Double, Double, Double, Double, Double>> {
        private static final long serialVersionUID = 961744826525651057L;

        private GlobalSum5() {
        }

        public Tuple5<Double, Double, Double, Double, Double> reduce(Tuple5<Double, Double, Double, Double, Double> tuple5, Tuple5<Double, Double, Double, Double, Double> tuple52) {
            Double s0 = plus((Double) tuple5.f0, (Double) tuple52.f0);
            Double s1 = plus((Double) tuple5.f1, (Double) tuple52.f1);
            Double s2 = plus((Double) tuple5.f2, (Double) tuple52.f2);
            Double s3 = plus((Double) tuple5.f3, (Double) tuple52.f3);
            Double s4 = plus((Double) tuple5.f4, (Double) tuple52.f4);
            return new Tuple5<>(s0, s1, s2, s3, s4);
        }

        /** Boxed sum of two boxed doubles. */
        private static Double plus(Double left, Double right) {
            return Double.valueOf(left.doubleValue() + right.doubleValue());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$GlobalWeightStat.class */
    /** Reduce function that combines two partial {@code WeightStat} accumulators via {@code WeightStat.merge}. */
    public static class GlobalWeightStat implements ReduceFunction<WeightStat> {
        private static final long serialVersionUID = 2180799423034826908L;

        private GlobalWeightStat() {
        }

        public WeightStat reduce(WeightStat weightStat, WeightStat weightStat2) throws Exception {
            WeightStat combined = WeightStat.merge(weightStat, weightStat2);
            return combined;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$InitData.class */
    /**
     * Replaces the label slot of a row copy with
     * {@code familyLink.predict(family.initialize(label, weight)) - offset}.
     * All other fields are carried over by {@code Row.copy}.
     */
    public static class InitData extends RichMapFunction<Row, Row> {
        private static final long serialVersionUID = 3777096603627082951L;
        private FamilyLink familyLink;
        private int numFeature;

        InitData(FamilyLink familyLink, int numFeature) {
            this.familyLink = familyLink;
            this.numFeature = numFeature;
        }

        public Row map(Row row) throws Exception {
            double label = ((Double) row.getField(this.numFeature)).doubleValue();
            double weight = ((Double) row.getField(this.numFeature + 1)).doubleValue();
            double predicted = this.familyLink.predict(this.familyLink.getFamily().initialize(label, weight));
            double offset = ((Double) row.getField(this.numFeature + 2)).doubleValue();

            Row result = Row.copy(row);
            result.setField(this.numFeature, Double.valueOf(predicted - offset));
            return result;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$IterCriterion.class */
    /**
     * Filter that keeps a pair of models when the largest absolute difference between
     * their intercepts and coefficients exceeds {@code epsilon}.
     */
    public static class IterCriterion implements FilterFunction<Tuple2<WeightedLeastSquaresModel, WeightedLeastSquaresModel>> {
        private static final long serialVersionUID = 3456664362450700037L;
        private double epsilon;

        private IterCriterion(double epsilon) {
            this.epsilon = epsilon;
        }

        public boolean filter(Tuple2<WeightedLeastSquaresModel, WeightedLeastSquaresModel> tuple2) {
            WeightedLeastSquaresModel first = (WeightedLeastSquaresModel) tuple2.f0;
            WeightedLeastSquaresModel second = (WeightedLeastSquaresModel) tuple2.f1;
            double largestChange = Math.abs(second.intercept - first.intercept);
            int idx = 0;
            while (idx < second.coefficients.length) {
                double change = Math.abs(second.coefficients[idx] - first.coefficients[idx]);
                // explicit comparison (not Math.max) so a NaN change never replaces the current maximum
                if (change > largestChange) {
                    largestChange = change;
                }
                idx++;
            }
            return largestChange > this.epsilon;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$LocalWeightStat.class */
    /**
     * Accumulates every row of a partition into one {@code WeightStat} and emits it.
     * The first {@code featureSize} fields are the features; the next two fields are
     * passed as the second and third arguments of {@code WeightStat.add}.
     */
    public static class LocalWeightStat extends RichMapPartitionFunction<Row, WeightStat> {
        private static final long serialVersionUID = -6986101578118695071L;
        private int featureSize;
        private double[] features;

        public LocalWeightStat(int featureSize) {
            this.featureSize = featureSize;
            this.features = new double[featureSize];
        }

        public void mapPartition(Iterable<Row> iterable, Collector<WeightStat> collector) throws Exception {
            WeightStat accumulator = new WeightStat(this.featureSize);
            Iterator<Row> rows = iterable.iterator();
            while (rows.hasNext()) {
                Row current = rows.next();
                // the feature buffer is reused across rows
                int col = 0;
                while (col < this.featureSize) {
                    this.features[col] = ((Number) current.getField(col)).doubleValue();
                    col++;
                }
                double label = ((Double) current.getField(this.featureSize)).doubleValue();
                double weight = ((Double) current.getField(this.featureSize + 1)).doubleValue();
                accumulator.add(this.features, label, weight);
            }
            collector.collect(accumulator);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$ModelAddId.class */
    /** Pairs every model with the constant integer key 0. */
    public static class ModelAddId implements MapFunction<WeightedLeastSquaresModel, Tuple2<Integer, WeightedLeastSquaresModel>> {
        private static final long serialVersionUID = 1017625416129768915L;

        private ModelAddId() {
        }

        public Tuple2<Integer, WeightedLeastSquaresModel> map(WeightedLeastSquaresModel weightedLeastSquaresModel) throws Exception {
            Integer key = Integer.valueOf(0);
            return new Tuple2<Integer, WeightedLeastSquaresModel>(key, weightedLeastSquaresModel);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$NullProcMapFunc.class */
    /**
     * Projects a row onto three fields: those at offsets {@code numFeature},
     * {@code numFeature + 1} and {@code numFeature + 2}, in that order.
     */
    private static class NullProcMapFunc implements MapFunction<Row, Row> {
        private static final long serialVersionUID = 2049606496916252036L;
        private int numFeature;

        NullProcMapFunc(int numFeature) {
            this.numFeature = numFeature;
        }

        public Row map(Row row) throws Exception {
            Row projected = new Row(3);
            for (int k = 0; k < 3; k++) {
                projected.setField(k, row.getField(this.numFeature + k));
            }
            return projected;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$PossionAicTransform.class */
    /**
     * Maps a row to weight * log(P(label)) under a Poisson distribution whose mean is
     * the field at offset {@code numFeature + 3}. The label is truncated to an int.
     */
    public static class PossionAicTransform extends RichMapFunction<Row, Double> {
        private static final long serialVersionUID = -8301333372535179746L;
        private int numFeature;
        // not read anywhere in this class
        private double disp;

        public PossionAicTransform(int numFeature) {
            this.numFeature = numFeature;
        }

        public Double map(Row row) throws Exception {
            double weight = ((Double) row.getField(this.numFeature + 1)).doubleValue();
            double mean = ((Double) row.getField(this.numFeature + 3)).doubleValue();
            PoissonDistribution poisson = new PoissonDistribution(mean);
            int observed = (int) ((Double) row.getField(this.numFeature)).doubleValue();
            double logProb = Math.log(poisson.probability(observed));
            return Double.valueOf(weight * logProb);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$PreProcMapFunc.class */
    /** Map function that delegates each row to {@code GlmUtil.preProcRow} with the configured column indices. */
    private static class PreProcMapFunc implements MapFunction<Row, Row> {
        private static final long serialVersionUID = -1493915848795117611L;
        private int[] featureColIdxs;
        private int labelColIdx;
        private int weightColIdx;
        private int offsetColIdx;

        /**
         * @param featureColIdxs indices of the feature columns
         * @param labelColIdx    index of the label column
         * @param weightColIdx   index of the weight column
         * @param offsetColIdx   index of the offset column
         */
        PreProcMapFunc(int[] featureColIdxs, int labelColIdx, int weightColIdx, int offsetColIdx) {
            this.featureColIdxs = featureColIdxs;
            this.labelColIdx = labelColIdx;
            this.weightColIdx = weightColIdx;
            this.offsetColIdx = offsetColIdx;
        }

        public Row map(Row row) throws Exception {
            Row processed = GlmUtil.preProcRow(row, this.featureColIdxs, this.labelColIdx, this.weightColIdx, this.offsetColIdx);
            return processed;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$Residual.class */
    /**
     * Map function that delegates each row to {@code GlmUtil.residualRow}, using the
     * coefficients and intercept of the model broadcast under {@code OptimVariable.model}.
     */
    private static class Residual extends RichMapFunction<Row, Row> {
        private static final long serialVersionUID = -942482393961833349L;
        private int numFeature;
        private FamilyLink familyLink;
        private double[] coefficients;
        private double intercept;

        public Residual(int numFeature, FamilyLink familyLink) {
            this.numFeature = numFeature;
            this.familyLink = familyLink;
        }

        /** Caches the coefficients and intercept of the broadcast model. */
        public void open(Configuration configuration) {
            Object broadcast = getRuntimeContext().getBroadcastVariable(OptimVariable.model).get(0);
            WeightedLeastSquaresModel fitted = (WeightedLeastSquaresModel) broadcast;
            this.coefficients = fitted.coefficients;
            this.intercept = fitted.intercept;
        }

        public Row map(Row row) {
            Row withResidual = GlmUtil.residualRow(row, this.numFeature, this.familyLink, this.coefficients, this.intercept);
            return withResidual;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$UpdateData.class */
    /**
     * Re-computes a row through {@code familyLink.calcWeightAndLabel} using the model
     * broadcast under {@code OptimVariable.model}. The leading fields of the row are
     * unpacked into a double buffer, transformed, and packed into a new row.
     */
    public static class UpdateData extends RichMapFunction<Row, Row> {
        private static final long serialVersionUID = -6963213040879546310L;
        private FamilyLink familyLink;
        private double[] features;
        private double[] coefficients;
        private double intercept;

        /**
         * @param familyLink family/link pair performing the update
         * @param i          initial length of the working buffer
         */
        public UpdateData(FamilyLink familyLink, int i) {
            this.familyLink = familyLink;
            this.features = new double[i];
        }

        /** Caches the coefficients and intercept of the broadcast model. */
        public void open(Configuration configuration) throws Exception {
            Object broadcast = getRuntimeContext().getBroadcastVariable(OptimVariable.model).get(0);
            WeightedLeastSquaresModel fitted = (WeightedLeastSquaresModel) broadcast;
            this.coefficients = fitted.coefficients;
            this.intercept = fitted.intercept;
        }

        public Row map(Row row) throws Exception {
            double[] buffer = this.features;
            int col = 0;
            while (col < buffer.length) {
                buffer[col] = ((Double) row.getField(col)).doubleValue();
                col++;
            }
            // the returned array becomes the working buffer for subsequent calls as well
            this.features = this.familyLink.calcWeightAndLabel(this.coefficients, this.intercept, buffer);
            double[] updated = this.features;
            Row out = new Row(updated.length);
            int pos = 0;
            while (pos < updated.length) {
                out.setField(pos, Double.valueOf(updated[pos]));
                pos++;
            }
            return out;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$WeightStat.class */
    public static class WeightStat implements Serializable, AlinkSerializable {
        private static final long serialVersionUID = -4479171990155683717L;
        private int k;
        private int triK;
        private double[] aSum;
        private double[] abSum;
        private double[] aaSum;
        private long count = 0;
        private double wSum = Criteria.INVALID_GAIN;
        private double wwSum = Criteria.INVALID_GAIN;
        private double bSum = Criteria.INVALID_GAIN;
        private double bbSum = Criteria.INVALID_GAIN;
        private BLAS blas = F2jBLAS.getInstance();

        WeightStat(int i) {
            this.k = i;
            this.triK = (i * (i + 1)) / 2;
            this.aSum = new double[i];
            this.abSum = new double[i];
            this.aaSum = new double[this.triK];
        }

        public static WeightStat merge(WeightStat weightStat, WeightStat weightStat2) {
            WeightStat weightStat3 = new WeightStat(weightStat.k);
            weightStat3.merge(weightStat);
            weightStat3.merge(weightStat2);
            return weightStat3;
        }

        public void add(double[] dArr, double d, double d2) {
            this.count++;
            this.wSum += d2;
            this.wwSum += d2 * d2;
            this.bSum += d2 * d;
            this.bbSum += d2 * d * d;
            axpy(d2, dArr, this.aSum);
            axpy(d2 * d, dArr, this.abSum);
            spr(d2, dArr, this.aaSum);
        }

        public void merge(WeightStat weightStat) {
            this.count += weightStat.count;
            this.wSum += weightStat.wSum;
            this.wwSum += weightStat.wwSum;
            this.bSum += weightStat.bSum;
            this.bbSum += weightStat.bbSum;
            axpy(1.0d, weightStat.aSum, this.aSum);
            axpy(1.0d, weightStat.abSum, this.abSum);
            axpy(1.0d, weightStat.aaSum, this.aaSum);
        }

        double[] aMean() {
            double[] dArr = (double[]) this.aSum.clone();
            scal(1.0d / this.wSum, dArr);
            return dArr;
        }

        double bMean() {
            return this.bSum / this.wSum;
        }

        double bStdDeviation() {
            return Math.sqrt(Math.max((this.bbSum / this.wSum) - (bMean() * bMean()), Criteria.INVALID_GAIN));
        }

        /**
         * @return a new array holding {@code abSum} divided element-wise by {@code wSum}
         */
        double[] abMean() {
            final double[] mean = this.abSum.clone();
            final double inverseWeight = 1.0d / this.wSum;
            scal(inverseWeight, mean);
            return mean;
        }

        /**
         * @return a new packed array holding {@code aaSum} divided element-wise by {@code wSum}
         */
        double[] aaMean() {
            final double[] mean = this.aaSum.clone();
            final double inverseWeight = 1.0d / this.wSum;
            scal(inverseWeight, mean);
            return mean;
        }

        /**
         * @return per-feature square roots of {@link #aVariance()}, in a new array
         */
        double[] aStdDeviation() {
            final double[] deviation = aVariance();
            int idx = 0;
            while (idx < deviation.length) {
                // Converted in place: the array returned by aVariance() is freshly allocated.
                deviation[idx] = Math.sqrt(deviation[idx]);
                idx++;
            }
            return deviation;
        }

        /**
         * Computes a per-feature variance from the accumulated sums.
         *
         * <p>Entry {@code j} is {@code aaSum[diag] / wSum - (aSum[j] / wSum)^2}, bounded below
         * by {@code Criteria.INVALID_GAIN}. {@code diag} visits packed positions 0, 2, 5, 9, ...
         *
         * @return a new array of length {@code k}
         */
        double[] aVariance() {
            final double[] variance = new double[this.k];
            // From feature j to j + 1 the packed position advances by j + 2.
            for (int feature = 0, diag = 0; diag < this.triK; diag += feature + 2, feature++) {
                final double mean = this.aSum[feature] / this.wSum;
                final double secondMoment = this.aaSum[diag] / this.wSum;
                variance[feature] = Math.max(secondMoment - (mean * mean), Criteria.INVALID_GAIN);
            }
            return variance;
        }

        /**
         * Delegates to BLAS {@code daxpy} with unit strides over {@code dArr.length} elements
         * (per the BLAS reference this adds {@code d * dArr} into {@code dArr2}).
         */
        private void axpy(double d, double[] dArr, double[] dArr2) {
            final int n = dArr.length;
            this.blas.daxpy(n, d, dArr, 1, dArr2, 1);
        }

        /**
         * Delegates to BLAS {@code dspr} with the {@code "U"} option and unit stride
         * (per the BLAS reference a symmetric rank-1 update of the packed matrix {@code dArr2}).
         */
        private void spr(double d, double[] dArr, double[] dArr2) {
            final int n = dArr.length;
            this.blas.dspr("U", n, d, dArr, 1, dArr2);
        }

        /**
         * Delegates to BLAS {@code dscal} with unit stride
         * (per the BLAS reference this scales {@code dArr} in place by {@code d}).
         */
        private void scal(double d, double[] dArr) {
            final int n = dArr.length;
            this.blas.dscal(n, d, dArr, 1);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$WeightedLeastSquares.class */
    /**
     * Flink map-partition function that turns each incoming {@link WeightStat} into a
     * {@link WeightedLeastSquaresModel}: it rescales the accumulated moments, adds a
     * regularisation term to the diagonal, solves the resulting packed symmetric system
     * with LAPACK and rescales the solution.
     */
    public static class WeightedLeastSquares extends RichMapPartitionFunction<WeightStat, WeightedLeastSquaresModel> {
        private static final long serialVersionUID = -6227052527587948870L;
        // When true, the system gets one extra row/column and the model an intercept.
        private boolean fitIntercept;
        // Regularisation strength added to the diagonal (after rescaling, see mapPartition).
        private double regParam;
        // When false, the regularisation term is divided by the squared feature deviation.
        private boolean standardizeFeatures;
        // When false, the regularisation term is multiplied by the label scale.
        private boolean standardizeLabel;
        // Obtained in open(); null until then.
        private LAPACK lapack = null;

        /**
         * @param z  value for {@code fitIntercept}
         * @param d  value for {@code regParam}
         * @param z2 value for {@code standardizeFeatures}
         * @param z3 value for {@code standardizeLabel}
         */
        WeightedLeastSquares(boolean z, double d, boolean z2, boolean z3) {
            this.fitIntercept = z;
            this.regParam = d;
            this.standardizeFeatures = z2;
            this.standardizeLabel = z3;
        }

        /** Fetches the LAPACK instance used by {@code solve} and {@code inverse}. */
        public void open(Configuration configuration) {
            this.lapack = LAPACK.getInstance();
        }

        /**
         * Emits one model per {@link WeightStat} in the partition.
         *
         * @param iterable  accumulated statistics
         * @param collector receives one {@link WeightedLeastSquaresModel} per input element
         */
        public void mapPartition(Iterable<WeightStat> iterable, Collector<WeightedLeastSquaresModel> collector) throws Exception {
            for (WeightStat weightStat : iterable) {
                // i: order of the solved system (features plus optional intercept); i2: feature count.
                int i = this.fitIntercept ? weightStat.k + 1 : weightStat.k;
                int i2 = weightStat.k;
                int i3 = weightStat.triK;
                double d = weightStat.wSum;
                double bStdDeviation = weightStat.bStdDeviation();
                // Label scale: the deviation, or |mean| when the deviation equals Criteria.INVALID_GAIN.
                // NOTE(review): if the mean is also zero this scale is zero and the divisions
                // below produce non-finite values -- confirm callers never hit that case.
                double abs = bStdDeviation == Criteria.INVALID_GAIN ? Math.abs(weightStat.bMean()) : bStdDeviation;
                double bMean = weightStat.bMean() / abs;
                double[] aStdDeviation = weightStat.aStdDeviation();
                double[] aMean = weightStat.aMean();
                // Scale feature means by their deviation; features whose deviation equals
                // Criteria.INVALID_GAIN are zeroed instead.
                for (int i4 = 0; i4 < i2; i4++) {
                    if (Criteria.INVALID_GAIN == aStdDeviation[i4]) {
                        aMean[i4] = 0.0d;
                    } else {
                        int i5 = i4;
                        aMean[i5] = aMean[i5] / aStdDeviation[i4];
                    }
                }
                double[] abMean = weightStat.abMean();
                // Same treatment for the feature*label means, additionally divided by the label scale.
                for (int i6 = 0; i6 < i2; i6++) {
                    if (Criteria.INVALID_GAIN == aStdDeviation[i6]) {
                        abMean[i6] = 0.0d;
                    } else {
                        int i7 = i6;
                        abMean[i7] = abMean[i7] / (aStdDeviation[i6] * abs);
                    }
                }
                double[] aaMean = weightStat.aaMean();
                // Walk the packed array: outer index i9, inner index i10 <= i9; each entry is
                // divided by the product of the two deviations, or zeroed if either is invalid.
                int i8 = 0;
                for (int i9 = 0; i9 < i2; i9++) {
                    double d2 = aStdDeviation[i9];
                    for (int i10 = 0; i10 <= i9; i10++) {
                        double d3 = aStdDeviation[i10];
                        if (d2 == Criteria.INVALID_GAIN || d3 == Criteria.INVALID_GAIN) {
                            aaMean[i8] = 0.0d;
                        } else {
                            int i11 = i8;
                            aaMean[i11] = aaMean[i11] / (d3 * d2);
                        }
                        i8++;
                    }
                }
                // Regularisation: start from regParam / label scale and add it to every
                // diagonal entry of the packed array (positions 0, 2, 5, 9, ...).
                double d4 = this.regParam / abs;
                int i12 = 0;
                int i13 = 2;
                while (i12 < i3) {
                    double d5 = d4;
                    if (!this.standardizeFeatures) {
                        double d6 = aStdDeviation[i13 - 2];
                        d5 = d6 != Criteria.INVALID_GAIN ? d5 / (d6 * d6) : 0.0d;
                    }
                    if (!this.standardizeLabel) {
                        d5 *= abs;
                    }
                    int i14 = i12;
                    aaMean[i14] = aaMean[i14] + d5;
                    i12 += i13;
                    i13++;
                }
                double[] ata = getATA(aaMean, aMean);
                double[] atb = getATB(abMean, bMean);
                // NOTE(review): solve() hands `ata` itself to LAPACK dppsv and inverse() then calls
                // dpptri on the same array. Per the LAPACK docs dppsv leaves the Cholesky factor
                // in the array and dpptri expects that factor, so this order is presumably
                // required -- confirm before reordering.
                double[] solve = solve(ata, atb);
                double[] inverse = inverse(ata, atb.length);
                WeightedLeastSquaresModel weightedLeastSquaresModel = new WeightedLeastSquaresModel();
                weightedLeastSquaresModel.coefficients = new double[i2];
                for (int i15 = 0; i15 < i2; i15++) {
                    weightedLeastSquaresModel.coefficients[i15] = solve[i15];
                }
                if (this.fitIntercept) {
                    // The intercept is the last solution entry, multiplied back by the label scale.
                    weightedLeastSquaresModel.intercept = solve[i2] * abs;
                }
                // Undo the scaling on each coefficient; invalid-deviation features get 0.
                for (int i16 = 0; i16 < i2; i16++) {
                    if (aStdDeviation[i16] != Criteria.INVALID_GAIN) {
                        double[] dArr = weightedLeastSquaresModel.coefficients;
                        int i17 = i16;
                        dArr[i17] = dArr[i17] * (abs / aStdDeviation[i16]);
                    } else {
                        weightedLeastSquaresModel.coefficients[i16] = 0.0d;
                    }
                }
                weightedLeastSquaresModel.diagInvAtWA = new double[i];
                // Diagonal of the inverse (1-based i18 maps to packed index i18 + (i18-1)*i18/2 - 1),
                // divided by wSum times the squared feature deviation (1.0 for the intercept slot).
                // NOTE(review): a feature deviation of zero makes this a division by zero -- confirm intended.
                int i18 = 1;
                while (i18 <= i) {
                    weightedLeastSquaresModel.diagInvAtWA[i18 - 1] = inverse[(i18 + (((i18 - 1) * i18) / 2)) - 1] / (d * ((this.fitIntercept && i18 == i) ? 1.0d : aStdDeviation[i18 - 1] * aStdDeviation[i18 - 1]));
                    i18++;
                }
                weightedLeastSquaresModel.numInstances = weightStat.count;
                weightedLeastSquaresModel.fitIntercept = this.fitIntercept;
                collector.collect(weightedLeastSquaresModel);
            }
        }

        /**
         * Builds the packed left-hand side. Without an intercept this is a copy of
         * {@code dArr}; with one, {@code dArr2} and a trailing {@code 1.0} are appended.
         *
         * @param dArr  packed, scaled second moments
         * @param dArr2 scaled feature means
         */
        private double[] getATA(double[] dArr, double[] dArr2) {
            if (!this.fitIntercept) {
                return (double[]) dArr.clone();
            }
            double[] dArr3 = new double[dArr.length + dArr2.length + 1];
            System.arraycopy(dArr, 0, dArr3, 0, dArr.length);
            System.arraycopy(dArr2, 0, dArr3, dArr.length, dArr2.length);
            dArr3[dArr.length + dArr2.length] = 1.0d;
            return dArr3;
        }

        /**
         * Builds the right-hand side. Without an intercept this is a copy of {@code dArr};
         * with one, {@code d} is appended as the last entry.
         *
         * @param dArr scaled feature*label means
         * @param d    scaled label mean
         */
        private double[] getATB(double[] dArr, double d) {
            if (!this.fitIntercept) {
                return (double[]) dArr.clone();
            }
            double[] dArr2 = new double[dArr.length + 1];
            System.arraycopy(dArr, 0, dArr2, 0, dArr.length);
            dArr2[dArr.length] = d;
            return dArr2;
        }

        /**
         * Inverts the packed matrix in place with LAPACK {@code dpptri} and returns the
         * same array. If LAPACK reports a non-zero status, falls back to
         * {@code LinearSolver.underDeterminedSolve}.
         *
         * @param dArr packed matrix; overwritten
         * @param i    order of the matrix
         */
        private double[] inverse(double[] dArr, int i) {
            intW intw = new intW(0);
            this.lapack.dpptri("U", i, dArr, intw);
            if (intw.val != 0) {
                DenseMatrix matrix = toMatrix(dArr, i);
                DenseMatrix eye = DenseMatrix.eye(i, i);
                // NOTE(review): presumably underDeterminedSolve leaves its result in `eye` -- verify.
                LinearSolver.underDeterminedSolve(matrix, eye);
                // Copy the entries with row <= column back, row by row.
                // NOTE(review): mapPartition reads the diagonal at positions 0, 2, 5, ...; a
                // row-by-row layout only matches those positions for i <= 2 -- confirm.
                int i2 = 0;
                for (int i3 = 0; i3 < i; i3++) {
                    for (int i4 = 0; i4 < i; i4++) {
                        if (i3 <= i4) {
                            dArr[i2] = eye.get(i3, i4);
                            i2++;
                        }
                    }
                }
            }
            return dArr;
        }

        /**
         * Inverts a square dense matrix with LAPACK {@code dgetrf} followed by {@code dgetri}.
         * The LAPACK status in {@code intw} is not inspected.
         */
        private DenseMatrix inverse(DenseMatrix denseMatrix) {
            int numCols = denseMatrix.numCols();
            double[] dArr = new double[numCols * numCols];
            for (int i = 0; i < numCols; i++) {
                for (int i2 = 0; i2 < numCols; i2++) {
                    dArr[(i * numCols) + i2] = denseMatrix.get(i, i2);
                }
            }
            int[] iArr = new int[numCols + 1];
            intW intw = new intW(0);
            this.lapack.dgetrf(numCols, numCols, dArr, numCols, iArr, intw);
            int i3 = numCols * numCols;
            this.lapack.dgetri(numCols, dArr, numCols, iArr, new double[i3], i3, intw);
            return new DenseMatrix(numCols, numCols, dArr, true);
        }

        /**
         * Solves the packed symmetric system with LAPACK {@code dppsv}. The right-hand side
         * is cloned first; {@code dArr} is passed to LAPACK directly. On a non-zero LAPACK
         * status the solution is recomputed as {@code inverse(M^T M) * (M^T b)}.
         *
         * @param dArr  packed left-hand side
         * @param dArr2 right-hand side; not modified
         * @return the solution vector
         */
        /* JADX WARN: Type inference failed for: r3v2, types: [double[], double[][]] */
        private double[] solve(double[] dArr, double[] dArr2) {
            double[] dArr3 = (double[]) dArr2.clone();
            int length = dArr2.length;
            intW intw = new intW(0);
            this.lapack.dppsv("U", length, 1, dArr, dArr3, length, intw);
            if (intw.val != 0) {
                // NOTE(review): `dArr` may already have been partially overwritten by the failed
                // dppsv call when it is unpacked here -- confirm this fallback is sound.
                DenseMatrix matrix = toMatrix(dArr, length);
                DenseMatrix multiplies = inverse(matrix.transpose().multiplies(matrix)).multiplies(matrix.transpose().multiplies(new DenseMatrix(new double[]{dArr2}).transpose()));
                for (int i = 0; i < length; i++) {
                    dArr3[i] = multiplies.get(i, 0);
                }
            }
            return dArr3;
        }

        /**
         * Expands a packed array into a full symmetric {@code i x i} matrix. Values are
         * consumed in order and placed row by row: (0,0), (0,1), ..., (0,i-1), (1,1), ...,
         * each mirrored across the diagonal.
         *
         * NOTE(review): the LAPACK/BLAS "U" packed format used elsewhere in this class is
         * column-by-column per the LAPACK docs; the two orders coincide only for i <= 2 --
         * confirm which layout is intended.
         */
        private DenseMatrix toMatrix(double[] dArr, int i) {
            DenseMatrix denseMatrix = new DenseMatrix(i, i);
            int i2 = 0;
            int i3 = 0;
            int i4 = i;
            for (double d : dArr) {
                denseMatrix.set(i2, i3, d);
                denseMatrix.set(i3, i2, d);
                i3++;
                if (i3 == i) {
                    i3 = (i - i4) + 1;
                    i4--;
                    i2++;
                }
            }
            return denseMatrix;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$WeightedLeastSquaresModel.class */
    /**
     * Result holder produced by {@code WeightedLeastSquares}: plain public fields,
     * serializable so it can travel between Flink operators.
     */
    public static class WeightedLeastSquaresModel implements Serializable, AlinkSerializable {
        private static final long serialVersionUID = -7582953557426871033L;
        // One coefficient per feature.
        public double[] coefficients;
        // Intercept term; stays at Criteria.INVALID_GAIN unless the producer assigns it.
        public double intercept = Criteria.INVALID_GAIN;
        // Scaled diagonal of the inverted system matrix, one entry per solved unknown.
        public double[] diagInvAtWA;
        // Whether an intercept was fitted.
        public boolean fitIntercept;
        // Number of observations the statistics were accumulated from.
        public long numInstances;
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/regression/glm/GlmUtil$summayTransform.class */
    /**
     * Maps one row to the five per-row terms that are later summed for the model summary.
     * The intercept is taken from the {@code "intercept"} broadcast set.
     */
    private static class summayTransform extends RichMapFunction<Row, Tuple5<Double, Double, Double, Double, Double>> {
        private static final long serialVersionUID = 2101192321315062185L;
        private double intercept;
        private FamilyLink familyLink;
        private int numFeature;

        /**
         * @param i          index of the first non-feature field in each row
         * @param familyLink family/link pair used for the deviance computations
         */
        public summayTransform(int i, FamilyLink familyLink) {
            this.numFeature = i;
            this.familyLink = familyLink;
        }

        /** Reads the first element of the {@code "intercept"} broadcast variable. */
        public void open(Configuration configuration) {
            Object broadcastValue = getRuntimeContext().getBroadcastVariable("intercept").get(0);
            this.intercept = ((Double) broadcastValue).doubleValue();
        }

        /**
         * @param row input row; fields {@code numFeature}, +1, +2, +3 and +5 are read as
         *            {@code Double} (presumably label, weight, offset, fitted value and
         *            Pearson residual as laid out by {@code residualRow}; confirm)
         * @return (deviance against the intercept-only prediction, deviance against the
         *         field at +3, dispersion term, field at +1, constant 1.0)
         */
        public Tuple5<Double, Double, Double, Double, Double> map(Row row) {
            final int base = this.numFeature;
            final double label = doubleAt(row, base);
            final double weight = doubleAt(row, base + 1);
            final double offset = doubleAt(row, base + 2);
            final double fitted = doubleAt(row, base + 3);
            final double pearson = doubleAt(row, base + 5);

            final double interceptOnlyMean = this.familyLink.getLink().unlink(this.intercept + offset);
            final double nullDeviance = this.familyLink.getFamily().deviance(label, interceptOnlyMean, weight);
            final double residualDeviance = this.familyLink.getFamily().deviance(label, fitted, weight);

            // Binomial and Poisson use a fixed term of 1.0; every other family uses the square.
            final boolean fixedDispersion = this.familyLink.getFamilyName().equals("Binomial")
                || this.familyLink.getFamilyName().equals("Poisson");
            final double dispersionTerm = fixedDispersion ? 1.0d : pearson * pearson;

            return new Tuple5<>(nullDeviance, residualDeviance, dispersionTerm, weight, 1.0d);
        }

        /** Reads field {@code index} of {@code row} as a primitive double. */
        private static double doubleAt(Row row, int index) {
            return ((Double) row.getField(index)).doubleValue();
        }
    }

    /**
     * Computes the dot product of {@code dArr} and {@code dArr2} over the length of
     * {@code dArr}, then adds {@code d}.
     *
     * @param dArr  first vector; its length bounds the loop
     * @param d     constant added after the dot product
     * @param dArr2 second vector; must have at least {@code dArr.length} elements
     * @return {@code sum(dArr[j] * dArr2[j]) + d}
     */
    public static double linearPredict(double[] dArr, double d, double[] dArr2) {
        double dot = 0.0d;
        int position = 0;
        for (double left : dArr) {
            dot += left * dArr2[position];
            position++;
        }
        return dot + d;
    }

    /**
     * Evaluates {@code familyLink.fitted} on the linear predictor plus {@code d2}.
     *
     * @param dArr       first vector passed to {@link #linearPredict}
     * @param d          constant passed to {@link #linearPredict}
     * @param dArr2      second vector passed to {@link #linearPredict}
     * @param d2         value added to the linear predictor before applying {@code fitted}
     * @param familyLink supplies the {@code fitted} transformation
     * @return the fitted value
     */
    public static double predict(double[] dArr, double d, double[] dArr2, double d2, FamilyLink familyLink) {
        final double eta = linearPredict(dArr, d, dArr2) + d2;
        return familyLink.fitted(eta);
    }

    /**
     * Resolves the requested column names to indices and maps the operator's data set
     * through {@code PreProcMapFunc}.
     *
     * @param batchOperator source operator
     * @param strArr        feature column names; must be non-null and non-empty
     * @param str           optional column name, resolved last and passed as the fourth
     *                      {@code PreProcMapFunc} argument ({@code -1} when null or empty)
     * @param str2          optional column name, passed as the third
     *                      {@code PreProcMapFunc} argument ({@code -1} when null or empty)
     * @param str3          label column name; must be non-null
     * @return the mapped data set
     * @throws AkIllegalOperatorParameterException if the feature or label columns are not set
     */
    public static DataSet<Row> preProc(BatchOperator batchOperator, String[] strArr, String str, String str2, String str3) {
        if (strArr == null || strArr.length == 0) {
            throw new AkIllegalOperatorParameterException("featureColNames must be set.");
        }
        final int[] featureIndices = new int[strArr.length];
        int slot = 0;
        for (String featureCol : strArr) {
            featureIndices[slot] = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), featureCol);
            slot++;
        }
        if (str3 == null) {
            throw new AkIllegalOperatorParameterException("labelColName must be set.");
        }
        final int labelIndex = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str3);
        // Optional columns: -1 marks "not provided".
        final int thirdArgIndex = (str2 == null || str2.isEmpty())
            ? -1
            : TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str2);
        final int fourthArgIndex = (str == null || str.isEmpty())
            ? -1
            : TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), str);
        return batchOperator.getDataSet()
            .map(new PreProcMapFunc(featureIndices, labelIndex, thirdArgIndex, fourthArgIndex))
            .name("PreProcMapFunc");
    }

    /**
     * Builds the Flink plan that fits the model.
     *
     * <p>For a Gaussian family with identity link, a single weighted least-squares pass is
     * used. Otherwise an initial weighted least-squares model seeds a bulk iteration that
     * repeatedly re-weights the data ({@code UpdateData}) and re-solves, stopping after
     * {@code i2} rounds or when the {@code IterCriterion} filter yields no element.
     *
     * @param dataSet    pre-processed rows
     * @param i          feature count passed to the statistics and data-preparation functions
     * @param familyLink family/link pair of the model
     * @param d          regularisation parameter passed to {@code WeightedLeastSquares}
     * @param z          whether to fit an intercept
     * @param i2         maximum number of iterations
     * @param d2         argument of {@code IterCriterion} (presumably the convergence tolerance; confirm)
     * @return data set holding the fitted model
     */
    public static DataSet<WeightedLeastSquaresModel> train(DataSet<Row> dataSet, int i, FamilyLink familyLink, double d, boolean z, int i2, double d2) {
        // Declared as DataSet: IterativeDataSet.closeWith returns a DataSet, which cannot be
        // assigned to an Operator-typed local.
        DataSet<WeightedLeastSquaresModel> model;
        FamilyFunction family = familyLink.getFamily();
        LinkFunction link = familyLink.getLink();
        boolean gaussianIdentity = family.name().equals(new Gaussian().name()) && link.name().equals(new Identity().name());
        if (gaussianIdentity) {
            model = dataSet.map(new GaussianDataProc(i))
                .mapPartition(new LocalWeightStat(i)).name("init LocalWeightStat")
                .reduce(new GlobalWeightStat()).name("init GlobalWeightStat")
                .mapPartition(new WeightedLeastSquares(z, d, true, true)).setParallelism(1).name("init WeightedLeastSquares");
        } else {
            // Raw types are kept for the iteration head and body so the projected join
            // below type-checks exactly as before.
            IterativeDataSet loop = dataSet.map(new InitData(familyLink, i))
                .mapPartition(new LocalWeightStat(i)).name("init LocalWeightStat")
                .reduce(new GlobalWeightStat()).name("init GlobalWeightStat")
                .mapPartition(new WeightedLeastSquares(z, d, true, true)).setParallelism(1).name("init WeightedLeastSquares")
                .iterate(i2).name("loop");
            Operator updated = dataSet.map(new UpdateData(familyLink, i + 3)).name("UpdateData")
                .withBroadcastSet(loop, OptimVariable.model)
                .mapPartition(new LocalWeightStat(i)).name("localWeightStat")
                .reduce(new GlobalWeightStat()).name("GlobalWeightStat")
                .mapPartition(new WeightedLeastSquares(z, d, false, false)).setParallelism(1).name("WLS");
            // Termination criterion: pair the previous and the updated model and filter the pair.
            model = loop.closeWith(updated, loop.map(new ModelAddId()).join(updated.map(new ModelAddId())).where(new int[]{0}).equalTo(new int[]{0}).projectFirst(new int[]{1}).projectSecond(new int[]{1}).filter(new IterCriterion(d2)));
        }
        return model;
    }

    /**
     * Maps every row through {@code Residual}, with the model broadcast under
     * {@code OptimVariable.model}.
     *
     * @param dataSet    fitted model, broadcast to the mapper
     * @param dataSet2   rows to map
     * @param i          first constructor argument of {@code Residual}
     * @param familyLink family/link pair of the model
     * @return the mapped rows
     */
    public static DataSet<Row> residual(DataSet<WeightedLeastSquaresModel> dataSet, DataSet<Row> dataSet2, int i, FamilyLink familyLink) {
        return dataSet2
            .map(new Residual(i, familyLink))
            .withBroadcastSet(dataSet, OptimVariable.model);
    }

    /**
     * Builds the plan that aggregates per-row terms into a {@code GlmModelSummary}.
     *
     * <p>The {@code "intercept"} broadcast set is chosen as follows: without an intercept
     * ({@code z == false}) it is the reduction of a constant {@code Criteria.INVALID_GAIN}
     * per row; for Gaussian/Identity it comes from {@code GaussianIntercept}; otherwise it
     * is the intercept of a model trained with zero features.
     *
     * @param dataSet    rows read by {@code summayTransform}
     * @param dataSet2   fitted model, broadcast under {@code OptimVariable.model}
     * @param i          feature count
     * @param familyLink family/link pair of the model
     * @param d          regularisation parameter forwarded to {@link #train}
     * @param i2         iteration limit forwarded to {@link #train}
     * @param d2         last argument forwarded to {@link #train}
     * @param z          whether the model has an intercept
     * @return data set holding the summary
     */
    public static DataSet<GlmModelSummary> aggSummary(DataSet<Row> dataSet, DataSet<WeightedLeastSquaresModel> dataSet2, int i, FamilyLink familyLink, double d, int i2, double d2, boolean z) {
        MapOperator<Row, Tuple5<Double, Double, Double, Double, Double>> perRowTerms = dataSet.map(new summayTransform(i, familyLink));

        DataSet<?> interceptSet;
        if (z) {
            boolean gaussianIdentity = familyLink.getFamilyName().equals("Gaussian") && familyLink.getLinkName().equals("Identity");
            if (gaussianIdentity) {
                interceptSet = dataSet.map(new GaussianInterceptMap(i)).reduce(new GlobalSum2()).mapPartition(new GaussianIntercept());
            } else {
                interceptSet = train(dataSet.map(new NullProcMapFunc(i)), 0, familyLink, d, true, i2, d2).mapPartition(new MapPartitionFunction<WeightedLeastSquaresModel, Double>() {
                    private static final long serialVersionUID = 5705040804798450030L;

                    public void mapPartition(Iterable<WeightedLeastSquaresModel> iterable, Collector<Double> collector) {
                        for (WeightedLeastSquaresModel nullModel : iterable) {
                            collector.collect(Double.valueOf(nullModel.intercept));
                        }
                    }
                });
            }
        } else {
            interceptSet = dataSet.map(new MapFunction<Row, Double>() {
                private static final long serialVersionUID = -8549889611112875368L;

                public Double map(Row row) {
                    return Double.valueOf(Criteria.INVALID_GAIN);
                }
            }).reduce(new GlobalSum());
        }

        DataSet<Tuple5<Double, Double, Double, Double, Double>> deviance = perRowTerms
            .withBroadcastSet(interceptSet, "intercept")
            .reduce(new GlobalSum5());
        DataSet<Double> aicSet = aic(dataSet, deviance, i, familyLink);
        return aicSet.map(new AggSummary(z, i, familyLink.getFamilyName()))
            .withBroadcastSet(deviance, "deviance")
            .withBroadcastSet(aicSet, "aic")
            .withBroadcastSet(dataSet2, OptimVariable.model);
    }

    /**
     * Builds the plan that computes the AIC-related value for the given family.
     *
     * <ul>
     *   <li>Tweedie: constant {@code Double.MAX_VALUE}.</li>
     *   <li>Binomial and Poisson: {@code -2} times the reduced per-row value.</li>
     *   <li>{@code LdaVariable.gamma}: {@code -2} times the reduced per-row value, plus 2.</li>
     *   <li>Gaussian: reduced per-row value passed through {@code GaussionAicTransform2}.</li>
     * </ul>
     *
     * NOTE(review): the gamma branch compares against {@code LdaVariable.gamma}; confirm that
     * constant matches the name reported by the Gamma family.
     *
     * @param dataSet    rows the per-row transforms read
     * @param dataSet2   aggregated deviance tuple, broadcast as {@code "deviance"} where needed
     * @param i          feature count passed to the per-row transforms
     * @param familyLink supplies the family name
     * @return data set holding the AIC-related value
     * @throws AkUnsupportedOperationException for any other family name
     */
    private static DataSet<Double> aic(DataSet<Row> dataSet, DataSet<Tuple5<Double, Double, Double, Double, Double>> dataSet2, int i, FamilyLink familyLink) {
        // Each branch returns a DataSet<Double>: the chains end in different operator types
        // (map, mapPartition), so a MapOperator-typed local cannot hold all of them.
        String familyName = familyLink.getFamilyName();
        if (familyName.equals("Tweedie")) {
            return dataSet2.map(new MapFunction<Tuple5<Double, Double, Double, Double, Double>, Double>() {
                private static final long serialVersionUID = 218357090555877059L;

                public Double map(Tuple5<Double, Double, Double, Double, Double> tuple5) throws Exception {
                    return Double.valueOf(Double.MAX_VALUE);
                }
            });
        }
        if (familyName.equals("Binomial")) {
            return dataSet.map(new BinomialAicTransform(i)).reduce(new GlobalSum()).mapPartition(new MapPartitionFunction<Double, Double>() {
                private static final long serialVersionUID = -6753822359386684791L;

                public void mapPartition(Iterable<Double> iterable, Collector<Double> collector) {
                    for (Double total : iterable) {
                        collector.collect(Double.valueOf((-2.0d) * total.doubleValue()));
                    }
                }
            });
        }
        if (familyName.equals(LdaVariable.gamma)) {
            return dataSet.map(new GammaAicTransform(i)).withBroadcastSet(dataSet2, "deviance").reduce(new GlobalSum()).mapPartition(new MapPartitionFunction<Double, Double>() {
                private static final long serialVersionUID = -3707908727123276637L;

                public void mapPartition(Iterable<Double> iterable, Collector<Double> collector) {
                    for (Double total : iterable) {
                        collector.collect(Double.valueOf(((-2.0d) * total.doubleValue()) + 2.0d));
                    }
                }
            });
        }
        if (familyName.equals("Poisson")) {
            return dataSet.map(new PossionAicTransform(i)).withBroadcastSet(dataSet2, "deviance").reduce(new GlobalSum()).mapPartition(new MapPartitionFunction<Double, Double>() {
                private static final long serialVersionUID = -7690462421422742899L;

                public void mapPartition(Iterable<Double> iterable, Collector<Double> collector) {
                    for (Double total : iterable) {
                        collector.collect(Double.valueOf((-2.0d) * total.doubleValue()));
                    }
                }
            });
        }
        if (!familyName.equals("Gaussian")) {
            throw new AkUnsupportedOperationException("family name not support yet." + familyName);
        }
        return dataSet.map(new GaussionAicTransform1(i)).reduce(new GlobalSum()).map(new GaussionAicTransform2()).withBroadcastSet(dataSet2, "deviance");
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns a copy of {@code row} extended by five fields: the prediction followed by the
     * deviance, Pearson, working and response residuals.
     *
     * @param row        input row; fields {@code 0..i-1} are read as numbers, fields
     *                   {@code i}, {@code i + 1} and {@code i + 2} as {@code Double}
     * @param i          number of leading numeric fields
     * @param familyLink family/link pair used for prediction and residuals
     * @param dArr       first vector passed to {@link #predict}
     * @param d          constant passed to {@link #predict}
     * @return a new row of arity {@code row.getArity() + 5}
     */
    public static Row residualRow(Row row, int i, FamilyLink familyLink, double[] dArr, double d) {
        final double[] features = new double[i];
        for (int col = 0; col < features.length; col++) {
            features[col] = ((Number) row.getField(col)).doubleValue();
        }
        final double observed = ((Double) row.getField(i)).doubleValue();
        final double weight = ((Double) row.getField(i + 1)).doubleValue();
        final double shift = ((Double) row.getField(i + 2)).doubleValue();

        final double fitted = predict(dArr, d, features, shift, familyLink);
        final double[] appended = {
            fitted,
            devianceResiduals(familyLink.getFamily(), observed, fitted, weight),
            pearResiduals(familyLink.getFamily(), observed, fitted, weight),
            workingResiduals(familyLink.getLink(), observed, fitted, weight),
            responseResiduals(observed, fitted)
        };

        final int arity = row.getArity();
        final Row extended = new Row(arity + appended.length);
        for (int col = 0; col < arity; col++) {
            extended.setField(col, row.getField(col));
        }
        for (int extra = 0; extra < appended.length; extra++) {
            extended.setField(arity + extra, Double.valueOf(appended[extra]));
        }
        return extended;
    }

    /**
     * Square root of the deviance (bounded below by {@code Criteria.INVALID_GAIN} first),
     * negated when {@code d <= d2}.
     *
     * @param familyFunction family providing {@code deviance}
     * @param d              first deviance argument
     * @param d2             second deviance argument
     * @param d3             third deviance argument
     */
    private static double devianceResiduals(FamilyFunction familyFunction, double d, double d2, double d3) {
        final double deviance = familyFunction.deviance(d, d2, d3);
        final double magnitude = Math.sqrt(Math.max(deviance, Criteria.INVALID_GAIN));
        // The comparison is kept as "<=" so NaN inputs leave the sign untouched.
        return (d <= d2) ? magnitude * -1.0d : magnitude;
    }

    /**
     * Computes {@code (d - d2) * sqrt(d3) / sqrt(variance(d2))}.
     *
     * @param familyFunction family providing {@code variance}
     * @param d              minuend of the difference
     * @param d2             subtrahend, also the argument of {@code variance}
     * @param d3             value whose square root scales the difference
     */
    private static double pearResiduals(FamilyFunction familyFunction, double d, double d2, double d3) {
        final double scaledDifference = (d - d2) * Math.sqrt(d3);
        final double deviation = Math.sqrt(familyFunction.variance(d2));
        return scaledDifference / deviation;
    }

    /**
     * Computes {@code (d - d2) * linkFunction.derivative(d2)}.
     *
     * @param linkFunction link providing {@code derivative}
     * @param d            minuend of the difference
     * @param d2           subtrahend, also the argument of {@code derivative}
     * @param d3           not used by this method
     */
    private static double workingResiduals(LinkFunction linkFunction, double d, double d2, double d3) {
        final double difference = d - d2;
        final double slope = linkFunction.derivative(d2);
        return difference * slope;
    }

    /**
     * @param d  minuend
     * @param d2 subtrahend
     * @return {@code d - d2}
     */
    private static double responseResiduals(double d, double d2) {
        final double difference = d - d2;
        return difference;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Projects {@code row} onto the selected columns followed by three trailing fields.
     *
     * @param row  input row
     * @param iArr indices of the columns copied first, in order
     * @param i    index of the column copied into the first trailing field
     * @param i2   index of the column for the second trailing field, or {@code -1} to store {@code 1.0}
     * @param i3   index of the column for the third trailing field, or {@code -1} to store
     *             {@code Criteria.INVALID_GAIN}
     * @return a new row of arity {@code iArr.length + 3}
     */
    public static Row preProcRow(Row row, int[] iArr, int i, int i2, int i3) {
        final int selected = iArr.length;
        final Row projected = new Row(selected + 3);
        int target = 0;
        for (int sourceIndex : iArr) {
            projected.setField(target, row.getField(sourceIndex));
            target++;
        }
        projected.setField(selected, row.getField(i));
        // -1 marks a column that was not provided; a default value is stored instead.
        Object second = (i2 == -1) ? Double.valueOf(1.0d) : row.getField(i2);
        projected.setField(selected + 1, second);
        Object third = (i3 == -1) ? Double.valueOf(Criteria.INVALID_GAIN) : row.getField(i3);
        projected.setField(selected + 2, third);
        return projected;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Evaluates a series over the {@code ckLanczos} coefficients with shift {@code gLanczos}
     * and combines it into
     * {@code (d - 0.5) * ln(t) - t + ln(sqrt(2 * pi) * series)}, where {@code t = d + gLanczos - 0.5}
     * (by the constant names and shape, the Lanczos log-gamma approximation).
     *
     * @param d argument; must be greater than {@code Criteria.INVALID_GAIN}
     * @return the computed value
     * @throws AkIllegalOperatorParameterException if {@code d <= Criteria.INVALID_GAIN}
     */
    public static double lnGamma(double d) {
        if (d <= Criteria.INVALID_GAIN) {
            throw new AkIllegalOperatorParameterException("para is out of range!");
        }
        double series = ckLanczos[0];
        int term = 1;
        while (term < ckLanczos.length) {
            series += ckLanczos[term] / ((d + term) - 1.0d);
            term++;
        }
        final double shifted = (d + gLanczos) - 0.5d;
        final double leading = (d - 0.5d) * Math.log(shifted);
        // 2.0 * Math.PI is exactly the double 6.283185307179586.
        final double tail = Math.log(Math.sqrt(2.0d * Math.PI) * series);
        return (leading - shifted) + tail;
    }
}
