package com.alibaba.alink.operator.common.statistics;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.jama.JMatrixFunc;
import com.alibaba.alink.common.utils.TableUtil;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/statistics/CorrespondenceAnalysis.class */
/**
 * Correspondence analysis of two categorical columns.
 *
 * <p>The public entry point counts how often each (row value, column value) pair occurs, arranges
 * the counts in a contingency table, and derives two-dimensional coordinates for every distinct
 * row value and column value from a singular value decomposition of the standardized residuals of
 * that table.
 */
public class CorrespondenceAnalysis {

    /** Orders group keys by row value first and column value second (natural {@code String} order). */
    private static final Comparator<GroupFactor> GROUP_ORDER =
        Comparator.comparing((GroupFactor g) -> g.rowExpr).thenComparing(g -> g.colExpr);

    /**
     * Key identifying one cell of the contingency table: a (row value, column value) pair.
     *
     * <p>Instances are immutable so they are safe to use as map keys; {@code equals} and
     * {@code hashCode} agree with the ordering used by the counting map.
     */
    public static class GroupFactor {
        final String rowExpr;
        final String colExpr;

        /**
         * @param str  the row-category value
         * @param str2 the column-category value
         */
        public GroupFactor(String str, String str2) {
            this.rowExpr = str;
            this.colExpr = str2;
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof GroupFactor)) {
                return false;
            }
            GroupFactor that = (GroupFactor) other;
            return Objects.equals(rowExpr, that.rowExpr) && Objects.equals(colExpr, that.colExpr);
        }

        @Override
        public int hashCode() {
            return Objects.hash(rowExpr, colExpr);
        }
    }

    /**
     * Runs correspondence analysis on two columns of the given rows.
     *
     * @param iterable the data rows
     * @param str      name of the column whose distinct values become the table rows
     * @param str2     name of the column whose distinct values become the table columns
     * @param strArr   names of all columns, used to resolve {@code str} and {@code str2} to field
     *                 indices via {@code TableUtil.findColIndexWithAssertAndHint}
     * @return the analysis result, with legends set to the two column names and tags set to the
     *         distinct values found in each column
     * @throws IllegalArgumentException if {@code iterable} yields no rows
     * @throws NullPointerException     if a row holds {@code null} in either selected field
     * @throws Exception                if the contingency table consists of a single cell
     */
    public static CorrespondenceAnalysisResult calc(Iterable<Row> iterable, String str, String str2, String[] strArr) throws Exception {
        int rowFieldIdx = TableUtil.findColIndexWithAssertAndHint(strArr, str);
        int colFieldIdx = TableUtil.findColIndexWithAssertAndHint(strArr, str2);
        Map<GroupFactor, Long> groupCount = getGroupCount(iterable, rowFieldIdx, colFieldIdx);

        List<Set<String>> distinctValues = getDistinctValue(groupCount);
        String[] rowTags = distinctValues.get(0).toArray(new String[0]);
        // A zero-length seed array keeps the tag count equal to the number of distinct values;
        // a one-element seed would yield a phantom null tag when there are no values at all.
        String[] colTags = distinctValues.get(1).toArray(new String[0]);

        CorrespondenceAnalysisResult result = calc(getPivotTable(rowTags, colTags, groupCount));
        result.rowLegend = str;
        result.colLegend = str2;
        result.rowTags = rowTags;
        result.colTags = colTags;
        return result;
    }

    /**
     * Counts the occurrences of every (row value, column value) pair.
     *
     * <p>Field values are converted with {@code toString()}, so a {@code null} field raises a
     * {@code NullPointerException}.
     *
     * @param rows        the data rows
     * @param rowFieldIdx index of the field providing the row value
     * @param colFieldIdx index of the field providing the column value
     * @return a sorted map from pair to number of occurrences
     */
    private static Map<GroupFactor, Long> getGroupCount(Iterable<Row> rows, int rowFieldIdx, int colFieldIdx) {
        Map<GroupFactor, Long> counts = new TreeMap<>(GROUP_ORDER);
        for (Row row : rows) {
            GroupFactor key = new GroupFactor(
                row.getField(rowFieldIdx).toString(),
                row.getField(colFieldIdx).toString());
            counts.merge(key, 1L, Long::sum);
        }
        return counts;
    }

    /**
     * Collects the distinct row values and distinct column values present in the counted pairs.
     *
     * @param counts the pair counts
     * @return a two-element list: element 0 is the set of row values, element 1 the set of column
     *         values
     */
    private static List<Set<String>> getDistinctValue(Map<GroupFactor, Long> counts) {
        Set<String> rowValues = new HashSet<>();
        Set<String> colValues = new HashSet<>();
        for (GroupFactor key : counts.keySet()) {
            rowValues.add(key.rowExpr);
            colValues.add(key.colExpr);
        }
        List<Set<String>> distinct = new ArrayList<>(2);
        distinct.add(rowValues);
        distinct.add(colValues);
        return distinct;
    }

    /**
     * Builds the contingency table for the given row and column tags.
     *
     * @param strArr  row tags; entry {@code i} labels table row {@code i}
     * @param strArr2 column tags; entry {@code j} labels table column {@code j}
     * @param map     pair counts; pairs that are absent contribute {@code 0}
     * @return a {@code strArr.length x strArr2.length} table of counts
     */
    static double[][] getPivotTable(String[] strArr, String[] strArr2, Map<GroupFactor, Long> map) {
        double[][] table = new double[strArr.length][strArr2.length];
        for (int i = 0; i < strArr.length; i++) {
            for (int j = 0; j < strArr2.length; j++) {
                Long count = map.get(new GroupFactor(strArr[i], strArr2[j]));
                if (count != null) {
                    table[i][j] = count.longValue();
                }
                // Otherwise the cell keeps the default 0.0 of a freshly allocated array.
            }
        }
        return table;
    }

    /**
     * Runs correspondence analysis on a contingency table.
     *
     * <p>Steps: normalize the table by its grand total, compute row and column masses (marginal
     * sums of the normalized table), form the standardized residuals
     * {@code (p_ij - r_i * c_j) / sqrt(r_i * c_j)}, decompose them with
     * {@code JMatrixFunc.svd}, and scale the resulting coordinates by the inverse square root of
     * the corresponding mass. Only the first two axes are reported.
     *
     * <p>NOTE(review): a row or column whose counts sum to zero makes {@code r_i * c_j} zero and
     * therefore produces NaN residuals; callers are expected to pass tables without empty
     * margins.
     *
     * @param dArr the contingency table; the column count is taken from the first row
     * @return the result holding dimensions, row/column coordinates, the first two singular values
     *         and their share of the total inertia
     * @throws IllegalArgumentException if the table is {@code null} or has no rows or no columns
     * @throws Exception                if the table consists of a single cell
     */
    static CorrespondenceAnalysisResult calc(double[][] dArr) throws Exception {
        if (dArr == null || dArr.length == 0 || dArr[0].length == 0) {
            throw new IllegalArgumentException(
                "The contingency table must contain at least one row and one column.");
        }
        int nRows = dArr.length;
        int nCols = dArr[0].length;
        if (nRows * nCols == 1) {
            throw new Exception("(the number of column expr) * ( number of row expr) must Greater than 2.!");
        }

        // Grand total of all counts.
        double total = 0.0d;
        for (double[] row : dArr) {
            for (int j = 0; j < nCols; j++) {
                total += row[j];
            }
        }

        // Relative frequencies and row masses.
        double[][] freq = new double[nRows][nCols];
        double[] rowMass = new double[nRows];
        for (int i = 0; i < nRows; i++) {
            double sum = 0.0d;
            for (int j = 0; j < nCols; j++) {
                freq[i][j] = dArr[i][j] / total;
                sum += freq[i][j];
            }
            rowMass[i] = sum;
        }

        // Column masses.
        double[] colMass = new double[nCols];
        for (int j = 0; j < nCols; j++) {
            double sum = 0.0d;
            for (int i = 0; i < nRows; i++) {
                sum += freq[i][j];
            }
            colMass[j] = sum;
        }

        // Standardized residuals and their sum of squares (total inertia).
        double[][] residual = new double[nRows][nCols];
        double inertia = 0.0d;
        for (int i = 0; i < nRows; i++) {
            for (int j = 0; j < nCols; j++) {
                double expected = rowMass[i] * colMass[j];
                residual[i][j] = (freq[i][j] - expected) / Math.sqrt(expected);
                inertia += residual[i][j] * residual[i][j];
            }
        }
        double scaledInertia = inertia * total;

        // svd[0] and svd[2] are each multiplied by svd[1] below, i.e. they are used as the left
        // vectors, the singular-value matrix and the right vectors respectively.
        // NOTE(review): confirm the exact shapes returned by JMatrixFunc.svd.
        DenseMatrix[] svd = JMatrixFunc.svd(new DenseMatrix(residual));
        int axes = Math.min(nRows, nCols);

        DenseMatrix rowCoords = svd[0].multiplies(svd[1]);
        scaleRowsByInverseSqrt(rowCoords, rowMass, axes);
        DenseMatrix colCoords = svd[2].multiplies(svd[1]);
        scaleRowsByInverseSqrt(colCoords, colMass, axes);

        // The second axis is reported only when the row coordinates have more than one column;
        // the same test governs the column coordinates and the second singular value.
        boolean hasSecondAxis = rowCoords.numCols() > 1;

        CorrespondenceAnalysisResult result = new CorrespondenceAnalysisResult();
        result.nrow = nRows;
        result.ncol = nCols;
        result.rowPos = firstTwoAxes(rowCoords, nRows, hasSecondAxis);
        result.colPos = firstTwoAxes(colCoords, nCols, hasSecondAxis);

        result.sv = new double[2];
        result.sv[0] = svd[1].get(0, 0);
        if (hasSecondAxis) {
            result.sv[1] = svd[1].get(1, 1);
        }

        // Share of the total inertia carried by each reported axis.
        result.pct = new double[2];
        result.pct[0] = ((total * result.sv[0]) * result.sv[0]) / scaledInertia;
        result.pct[1] = ((total * result.sv[1]) * result.sv[1]) / scaledInertia;
        return result;
    }

    /** Divides the first {@code axes} entries of row {@code i} by {@code sqrt(mass[i])}, in place. */
    private static void scaleRowsByInverseSqrt(DenseMatrix coords, double[] mass, int axes) {
        for (int i = 0; i < mass.length; i++) {
            double root = Math.sqrt(mass[i]);
            for (int k = 0; k < axes; k++) {
                coords.set(i, k, coords.get(i, k) / root);
            }
        }
    }

    /** Copies the first one or two columns of {@code coords} into a {@code count x 2} array. */
    private static double[][] firstTwoAxes(DenseMatrix coords, int count, boolean hasSecondAxis) {
        double[][] positions = new double[count][2];
        for (int i = 0; i < count; i++) {
            positions[i][0] = coords.get(i, 0);
            if (hasSecondAxis) {
                positions[i][1] = coords.get(i, 1);
            }
        }
        return positions;
    }
}
