package com.alibaba.alink.operator.batch.statistics;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.jama.JMatrixFunc;
import com.alibaba.alink.common.utils.AlinkSerializable;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.common.viz.VizData;
import com.alibaba.alink.common.viz.VizDataWriterInterface;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.statistics.HasStatLevel_L1;
import com.alibaba.alink.params.statistics.MultiCollinearityBatchParams;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.DATA)})
@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})
@NameCn("多重共线性")
@NameEn("MultiCollinearity")
/* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/MultiCollinearityBatchOp.class */
public class MultiCollinearityBatchOp extends BatchOperator<MultiCollinearityBatchOp> implements MultiCollinearityBatchParams<MultiCollinearityBatchOp> {
    private static final long serialVersionUID = -3276749170439192468L;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/MultiCollinearityBatchOp$Multicollinearity.class */
    public static class Multicollinearity {
        public String[] nameX;
        public double[] VIF;
        public double[] TOL;
        public double[] eigenValues;
        public double[] CI;
        public double[][] VarProp;
        double[][] correlation;
        double kappa;
        double lambdaMin;
        double lambdaMax;
        double[] vectorMin;

        /**
         * Builds the collinearity diagnostics for a set of columns of a summary table.
         *
         * <p>Steps, in order: validate the selected columns, extract their correlation sub-matrix,
         * eigen-decompose it, derive eigenvalues / condition indices / variance proportions, and finally
         * compute VIF and tolerance per column from {@code getR2}.
         *
         * @param summaryResultTable summary statistics of the input data; must not be {@code null}
         * @param strArr             names of the columns to analyse; {@code null} selects every column of
         *                           the table
         * @return the populated diagnostics object
         * @throws Exception if the table is {@code null}, a column is not Double/Long/Boolean, a column has
         *                   a zero count, a column contains missing or NaN values, or {@code getR2} fails
         */
        public static Multicollinearity calc(SummaryResultTable summaryResultTable, String[] strArr) throws Exception {
            if (summaryResultTable == null) {
                throw new Exception("srt is null!");
            }
            final String[] names = (strArr != null) ? strArr : summaryResultTable.colNames;
            final int n = names.length;

            // Resolve and validate every requested column before any numeric work starts.
            final int[] colIdx = new int[n];
            for (int k = 0; k < n; k++) {
                colIdx[k] = TableUtil.findColIndexWithAssert(summaryResultTable.colNames, names[k]);
                final Class<?> type = summaryResultTable.col(colIdx[k]).dataType;
                final boolean supported = type == Double.class || type == Long.class || type == Boolean.class;
                if (!supported) {
                    throw new Exception("col type must be double, bigint , boolean!");
                }
                if (summaryResultTable.col(colIdx[k]).count == 0) {
                    throw new Exception(names[k] + " count is zero, please choose cols again!");
                }
                final boolean hasMissing = summaryResultTable.col(colIdx[k]).countMissValue > 0;
                if (hasMissing || summaryResultTable.col(colIdx[k]).countNanValue > 0) {
                    throw new Exception("col " + names[k] + " has null value or nan value!");
                }
            }

            // Correlation matrix restricted to (and ordered like) the selected columns.
            final double[][] corrAll = summaryResultTable.getCorr();
            final double[][] corrSel = new double[n][n];
            for (int r = 0; r < n; r++) {
                for (int c = 0; c < n; c++) {
                    corrSel[r][c] = corrAll[colIdx[r]][colIdx[c]];
                }
            }

            // eig[0] is read as the eigenvector matrix, eig[1] as the matrix holding eigenvalues on its diagonal.
            final DenseMatrix[] eig = JMatrixFunc.eig(new DenseMatrix(corrSel));
            final DenseMatrix eigVectors = eig[0];
            final DenseMatrix eigValues = eig[1];

            final Multicollinearity result = new Multicollinearity();
            result.correlation = corrSel;
            result.nameX = Arrays.copyOf(names, n);
            result.vectorMin = new double[n];
            for (int k = 0; k < n; k++) {
                result.vectorMin[k] = eigVectors.get(k, 0);
            }

            // Eigenvalues are stored in reverse diagonal order and floored at 1e-12 so that the divisions
            // below never divide by zero or by a negative value.
            // NOTE(review): presumably eig() yields ascending eigenvalues, making this descending - confirm.
            result.eigenValues = new double[n];
            for (int pos = 0; pos < n; pos++) {
                final int src = (n - 1) - pos;
                result.eigenValues[pos] = Math.max(eigValues.get(src, src), 1.0E-12d);
            }
            result.lambdaMax = result.eigenValues[0];
            result.lambdaMin = result.eigenValues[n - 1];
            result.kappa = result.lambdaMax / result.lambdaMin;

            // Condition index of each component: sqrt(lambdaMax / lambda_k).
            result.CI = new double[n];
            for (int k = 0; k < n; k++) {
                result.CI[k] = Math.sqrt(result.lambdaMax / result.eigenValues[k]);
            }

            // phi[component][variable] = loading^2 / eigenvalue; phiSum[variable] is the column total.
            final double[][] phi = new double[n][n];
            final double[] phiSum = new double[n];
            for (int v = 0; v < n; v++) {
                phiSum[v] = 0.0d;
                for (int comp = 0; comp < n; comp++) {
                    final double loading = eigVectors.get(v, (n - 1) - comp);
                    phi[comp][v] = (loading * loading) / result.eigenValues[comp];
                    phiSum[v] += phi[comp][v];
                }
            }
            result.VarProp = new double[n][n];
            for (int comp = 0; comp < n; comp++) {
                for (int v = 0; v < n; v++) {
                    result.VarProp[comp][v] = phi[comp][v] / phiSum[v];
                }
            }

            result.VIF = new double[n];
            result.TOL = new double[n];
            final double vifCutoff = 100000.0d;
            final ArrayList<String> retained = new ArrayList<>(Arrays.asList(names));
            final double[][] cov = summaryResultTable.getCov();

            // Pass 1: VIF of each column against all the others; columns above the cutoff leave 'retained'.
            for (int k = names.length - 1; k >= 0; k--) {
                final ArrayList<String> others = new ArrayList<>(Arrays.asList(names));
                others.remove(k); // int argument: removes by position
                final double r2 = getR2(summaryResultTable, names[k], others.toArray(new String[0]), cov);
                result.VIF[k] = 1.0d / (1.0d - r2);
                result.TOL[k] = 1.0d / result.VIF[k];
                if (result.VIF[k] > vifCutoff) {
                    retained.remove(names[k]); // String argument: removes by value
                }
            }

            // Pass 2: columns at or below the cutoff are re-evaluated against the retained columns only.
            for (int k = names.length - 1; k >= 0; k--) {
                if (result.VIF[k] > vifCutoff) {
                    continue;
                }
                final ArrayList<String> others = new ArrayList<>(retained);
                others.remove(names[k]);
                final double r2 = getR2(summaryResultTable, names[k], others.toArray(new String[0]), cov);
                result.VIF[k] = 1.0d / (1.0d - r2);
                result.TOL[k] = 1.0d / result.VIF[k];
            }
            return result;
        }

        /**
         * Fits column {@code i} on the columns in {@code iArr} using the supplied covariance matrix and
         * returns the resulting ratio clipped to the range [0, 1].
         *
         * <p>Predictor columns with a zero count, or whose covariance with the target equals
         * {@code Criteria.INVALID_GAIN}, are left out of the fit. The coefficients come from solving the
         * predictor covariance system, falling back to {@code solveLS} when {@code solve} throws.
         *
         * @param summaryResultTable summary statistics of the data
         * @param i                  index of the target column
         * @param iArr               indices of the candidate predictor columns
         * @param str                name of the target column (used to look up its variance and count)
         * @param strArr             names of the candidate predictor columns (only its length is read)
         * @param dArr               covariance matrix indexed by table column index
         * @return the accumulated sum divided by {@code variance * (count - 1)} of the target, clipped to
         *         [0, 1]; NaN is passed through unchanged because both clip comparisons are false for NaN
         * @throws Exception if the table is empty, has fewer records than predictors, or the target has
         *                   no valid values
         */
        static double getR2(SummaryResultTable summaryResultTable, int i, int[] iArr, String str, String[] strArr, double[][] dArr) throws Exception {
            if (summaryResultTable.col(0).countTotal == 0) {
                throw new Exception("table is empty!");
            }
            if (summaryResultTable.col(0).countTotal < strArr.length) {
                throw new Exception("record size Less than features size!");
            }
            final int candidateCount = iArr.length;
            final long sampleCount = summaryResultTable.col(i).count;
            if (sampleCount == 0) {
                throw new Exception("Y valid value num is zero!");
            }

            // Keep only predictors that have data and a covariance with the target other than INVALID_GAIN.
            final ArrayList<Integer> kept = new ArrayList<>();
            for (int k = 0; k < candidateCount; k++) {
                final int col = iArr[k];
                if (summaryResultTable.col(col).count == 0 || dArr[col][i] == Criteria.INVALID_GAIN) {
                    continue;
                }
                kept.add(col);
            }
            final int m = kept.size();
            final int[] xIdx = new int[m];
            for (int k = 0; k < m; k++) {
                xIdx[k] = kept.get(k);
            }

            final double[] xMeans = new double[m];
            for (int k = 0; k < m; k++) {
                xMeans[k] = summaryResultTable.col(xIdx[k]).mean();
            }
            final double yMean = summaryResultTable.col(i).mean();

            // Linear system: cov(X, X) * beta = cov(X, y).
            final DenseMatrix covXX = new DenseMatrix(m, m);
            final DenseMatrix covXY = new DenseMatrix(m, 1);
            for (int r = 0; r < m; r++) {
                for (int c = 0; c < m; c++) {
                    covXX.set(r, c, dArr[xIdx[r]][xIdx[c]]);
                }
                covXY.set(r, 0, dArr[xIdx[r]][i]);
            }
            DenseMatrix beta;
            try {
                beta = covXX.solve(covXY);
            } catch (Exception solveFailure) {
                // Fall back to the least-squares solver when the direct solve fails.
                beta = covXX.solveLS(covXY);
            }

            // coef[0] is the intercept, coef[k + 1] the slope of the k-th kept predictor.
            final double[] coef = new double[m + 1];
            double intercept = yMean;
            for (int k = 0; k < m; k++) {
                coef[k + 1] = beta.get(k, 0);
                intercept -= xMeans[k] * coef[k + 1];
            }
            coef[0] = intercept;

            final double totalSs = summaryResultTable.col(str).variance() * (summaryResultTable.col(str).count - 1);
            final double shift = coef[0] - yMean;

            // NOTE(review): this looks like the expanded regression sum of squares - confirm.
            double acc = Criteria.INVALID_GAIN + (shift * shift * sampleCount);
            for (int k = 0; k < m; k++) {
                acc += 2.0d * shift * summaryResultTable.col(xIdx[k]).sum * coef[k + 1];
            }
            for (int a = 0; a < m; a++) {
                for (int b = 0; b < m; b++) {
                    acc += coef[a + 1] * coef[b + 1] * ((dArr[xIdx[a]][xIdx[b]] * (sampleCount - 1)) + (summaryResultTable.col(xIdx[a]).mean() * summaryResultTable.col(xIdx[b]).mean() * sampleCount));
                }
            }

            final double ratio = acc / totalSs;
            if (ratio < Criteria.INVALID_GAIN) {
                return 0.0d;
            }
            if (ratio > 1.0d) {
                return 1.0d;
            }
            return ratio;
        }

        /**
         * Name-based entry point: validates the target and predictor columns, resolves their indices and
         * delegates to the index-based {@code getR2} overload.
         *
         * @param summaryResultTable summary statistics of the data; must not be {@code null}
         * @param str                name of the target column
         * @param strArr             names of the predictor columns; must not be empty
         * @param dArr               covariance matrix passed through to the index-based overload
         * @return the value returned by the index-based overload
         * @throws Exception if the table is {@code null}, any involved column is neither Double nor Long,
         *                   or no predictor name is given
         */
        static double getR2(SummaryResultTable summaryResultTable, String str, String[] strArr, double[][] dArr) throws Exception {
            if (summaryResultTable == null) {
                throw new Exception("srt must not null!");
            }
            final String[] tableCols = summaryResultTable.colNames;

            // Snapshot of every column's data type, indexed like the table columns.
            final Class<?>[] colTypes = new Class<?>[tableCols.length];
            for (int c = 0; c < tableCols.length; c++) {
                colTypes[c] = summaryResultTable.col(c).dataType;
            }

            final int targetIdx = TableUtil.findColIndexWithAssertAndHint(tableCols, str);
            final Class<?> targetType = colTypes[targetIdx];
            if (!(targetType == Double.class || targetType == Long.class)) {
                throw new Exception("col type must be double or bigint!");
            }
            if (strArr.length == 0) {
                throw new Exception("nameX must input!");
            }
            for (String predictor : strArr) {
                final Class<?> predictorType = colTypes[TableUtil.findColIndexWithAssertAndHint(tableCols, predictor)];
                if (!(predictorType == Double.class || predictorType == Long.class)) {
                    throw new Exception("col type must be double or bigint!");
                }
            }

            final int[] predictorIdx = new int[strArr.length];
            for (int k = 0; k < predictorIdx.length; k++) {
                predictorIdx[k] = TableUtil.findColIndexWithAssert(tableCols, strArr[k]);
            }
            return getR2(summaryResultTable, targetIdx, predictorIdx, str, strArr, dArr);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/MultiCollinearityBatchOp$MulticollinearityFlatMap.class */
    public static class MulticollinearityFlatMap implements FlatMapFunction<SummaryResultTable, Row> {
        private static final long serialVersionUID = 7050574386992532014L;
        private String functionName;
        private VizDataWriterInterface node;

        public MulticollinearityFlatMap(VizDataWriterInterface vizDataWriterInterface, String str) {
            this.functionName = str;
            this.node = vizDataWriterInterface;
        }

        public void flatMap(SummaryResultTable summaryResultTable, Collector<Row> collector) throws Exception {
            try {
                long currentTimeMillis = System.currentTimeMillis();
                Multicollinearity calc = Multicollinearity.calc(summaryResultTable, null);
                MulticollinearityResult multicollinearityResult = new MulticollinearityResult();
                int length = calc.nameX.length;
                int i = length + 4;
                multicollinearityResult.rowNames = calc.nameX;
                multicollinearityResult.colNames = new String[i];
                multicollinearityResult.colNames[0] = "vif";
                multicollinearityResult.colNames[1] = "tol";
                multicollinearityResult.colNames[2] = "eigenvalue";
                multicollinearityResult.colNames[3] = "condition_indx";
                System.arraycopy(calc.nameX, 0, multicollinearityResult.colNames, 4, length);
                multicollinearityResult.data = new double[length][i];
                for (int i2 = 0; i2 < length; i2++) {
                    multicollinearityResult.data[i2][0] = calc.VIF[i2];
                    multicollinearityResult.data[i2][1] = calc.TOL[i2];
                    multicollinearityResult.data[i2][2] = calc.eigenValues[i2];
                    multicollinearityResult.data[i2][3] = calc.CI[i2];
                    System.arraycopy(calc.VarProp[i2], 0, multicollinearityResult.data[i2], 4, length);
                }
                Row row = new Row(7);
                String json = JsonConverter.gson.toJson(multicollinearityResult);
                row.setField(0, this.functionName);
                row.setField(1, "");
                row.setField(2, json);
                row.setField(3, Long.valueOf(currentTimeMillis));
                collector.collect(row);
                int i3 = this.functionName.equals("AllStat") ? 1 : 0;
                ArrayList arrayList = new ArrayList();
                arrayList.add(new VizData(i3, json, currentTimeMillis));
                this.node.writeStreamData(arrayList);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((SummaryResultTable) obj, (Collector<Row>) collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/MultiCollinearityBatchOp$MulticollinearityResult.class */
    /**
     * Plain data holder for the collinearity table that {@code MulticollinearityFlatMap.flatMap}
     * fills in and serialises to JSON through {@code JsonConverter.gson}.
     */
    public static class MulticollinearityResult implements AlinkSerializable {
        // Column headers: "vif", "tol", "eigenvalue", "condition_indx", followed by the feature names.
        String[] colNames;
        // Row headers: the analysed feature names (Multicollinearity.nameX).
        String[] rowNames;
        // One row per feature; the first four cells hold VIF, TOL, eigenvalue and CI, the rest VarProp.
        double[][] data;

        // Package-private constructor: instances are created inside this file only.
        MulticollinearityResult() {
        }
    }

    /**
     * Creates the operator without an initial parameter set; equivalent to passing {@code null}
     * parameters to the {@link Params}-based constructor.
     */
    public MultiCollinearityBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters, which are handed unchanged to the
     * {@code BatchOperator} super constructor.
     *
     * @param params operator parameters
     */
    public MultiCollinearityBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Links this operator to its first input and defines the output table of collinearity diagnostics.
     *
     * <p>The selected columns must be numerical. A level-L3 summary of those columns is turned into a
     * {@link Multicollinearity} result, which is then expanded into one output row per analysed column:
     * {@code feature_name, vif, tof, eigenvalue, condition_index} followed by one DOUBLE column per
     * selected column holding the variance proportions.
     *
     * <p>NOTE(review): the third output column is named "tof" although it carries the TOL values; this
     * looks like a typo for "tol" but is kept because renaming would break existing consumers - confirm.
     *
     * @param batchOperatorArr upstream operators; only the first one is used
     * @return this operator
     */
    @Override
    public MultiCollinearityBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        final String[] selectedCols = getSelectedCols();
        TableUtil.assertNumericalCols(in.getSchema(), selectedCols);

        setOutput(
            StatisticsHelper.getSRT(in.select(selectedCols), HasStatLevel_L1.StatLevel.L3)
                .map(new MapFunction<SummaryResultTable, Multicollinearity>() {
                    @Override
                    public Multicollinearity map(SummaryResultTable summaryResultTable) throws Exception {
                        return Multicollinearity.calc(summaryResultTable, selectedCols);
                    }
                })
                // The erased flatMap(Object, Collector) bridge is generated by the compiler and is
                // therefore not declared by hand in this anonymous class.
                .flatMap(new FlatMapFunction<Multicollinearity, Row>() {
                    @Override
                    public void flatMap(Multicollinearity multicollinearity, Collector<Row> collector) throws Exception {
                        for (int i = 0; i < multicollinearity.nameX.length; i++) {
                            Row row = new Row(5 + selectedCols.length);
                            row.setField(0, multicollinearity.nameX[i]);
                            row.setField(1, Double.valueOf(multicollinearity.VIF[i]));
                            row.setField(2, Double.valueOf(multicollinearity.TOL[i]));
                            row.setField(3, Double.valueOf(multicollinearity.eigenValues[i]));
                            row.setField(4, Double.valueOf(multicollinearity.CI[i]));
                            for (int j = 0; j < selectedCols.length; j++) {
                                row.setField(5 + j, Double.valueOf(multicollinearity.VarProp[i][j]));
                            }
                            collector.collect(row);
                        }
                    }
                }),
            mergeCols(new String[]{"feature_name", "vif", "tof", "eigenvalue", "condition_index"}, selectedCols),
            mergeColTypes(new TypeInformation[]{Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE}, Types.DOUBLE, selectedCols.length));
        return this;
    }

    /**
     * Concatenates two name arrays into a new array.
     *
     * @param strArr  names placed first
     * @param strArr2 names appended after {@code strArr}
     * @return a new array containing all of {@code strArr} followed by all of {@code strArr2}
     */
    private String[] mergeCols(String[] strArr, String[] strArr2) {
        // copyOf allocates the full-length result and fills the head; the tail is copied in afterwards.
        String[] merged = Arrays.copyOf(strArr, strArr.length + strArr2.length);
        System.arraycopy(strArr2, 0, merged, strArr.length, strArr2.length);
        return merged;
    }

    /**
     * Builds a type array made of the given leading types followed by {@code i} copies of one type.
     *
     * @param typeInformationArr types placed first
     * @param typeInformation    type repeated after the leading types
     * @param i                  number of repetitions of {@code typeInformation}
     * @return a new array of length {@code typeInformationArr.length + i}
     */
    private TypeInformation<?>[] mergeColTypes(TypeInformation<?>[] typeInformationArr, TypeInformation<?> typeInformation, int i) {
        final int prefix = typeInformationArr.length;
        TypeInformation<?>[] merged = new TypeInformation[prefix + i];
        System.arraycopy(typeInformationArr, 0, merged, 0, prefix);
        Arrays.fill(merged, prefix, prefix + i, typeInformation);
        return merged;
    }

    // The erased linkFrom(BatchOperator[]) bridge is generated by the compiler for the varargs
    // linkFrom(BatchOperator<?>...) override declared in this class. Declaring it by hand would
    // introduce a second method with the same erasure, which does not compile, so it is omitted.
}
