package com.alibaba.alink.operator.common.outlier;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.fe.define.BaseStatFeatures;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.probabilistic.IDF;
import com.alibaba.alink.params.outlier.CooksDistanceDetectorParams;
import com.alibaba.alink.params.outlier.WithMultiVarParams;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/**
 * Outlier detector based on Cook's distance.
 *
 * <p>The input is first projected to a table whose last column is treated as the regression label
 * (see {@link #getMTableWithLabel(MTable, Params)}). A least-squares model with an intercept is
 * fitted on the remaining columns, the Cook's distance of every row is computed, and a row is
 * flagged as an outlier when its distance exceeds {@code IDF.F(0.95, p, n - p)} (presumably the
 * 0.95 quantile of the F distribution with {@code (p, n - p)} degrees of freedom — confirm against
 * {@code IDF}).
 */
public class CooksDistanceDetector extends OutlierDetector {

    /** Level passed to {@code IDF.F} when computing the outlier threshold. */
    private static final double THRESHOLD_LEVEL = 0.95d;

    /**
     * Creates a detector.
     *
     * @param tableSchema schema of the input table
     * @param params      detector parameters (vector/feature columns, optional label column)
     */
    public CooksDistanceDetector(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
    }

    /**
     * Scores every row of {@code mTable} by its Cook's distance.
     *
     * @param mTable     input rows
     * @param detectLast ignored by this detector; every row is always scored
     * @return one entry per row: (distance exceeds the threshold, distance, details map with keys
     *         {@code "distance"}, {@code BaseStatFeatures.NUMBER} (row count), {@code "p"} (number of
     *         regression coefficients, intercept included) and {@code "f"} (the threshold))
     * @throws AkIllegalOperatorParameterException if there are not more rows than coefficients, or if
     *         a selected entry is null
     */
    @Override // com.alibaba.alink.operator.common.outlier.OutlierDetector
    public Tuple3<Boolean, Double, Map<String, String>>[] detect(MTable mTable, boolean detectLast) throws Exception {
        MTable data = getMTableWithLabel(mTable, this.params);
        // Dimensions must come from the projected table: feature/label selection or vector expansion
        // can change the column count relative to the raw input.
        int numRow = data.getNumRow();
        int numCol = data.getNumCol();
        // The design matrix has numCol coefficients (numCol - 1 features plus an intercept), so the
        // residual degrees of freedom numRow - numCol must be strictly positive.
        if (numRow <= numCol) {
            throw new AkIllegalOperatorParameterException("rowNum must be larger than colNum.");
        }

        int labelIdx = numCol - 1;
        DenseMatrix design = new DenseMatrix(numRow, numCol);
        DenseMatrix response = new DenseMatrix(numRow, 1);
        for (int row = 0; row < numRow; row++) {
            response.set(row, 0, getDoubleValue(data, row, labelIdx));
            for (int col = 0; col < labelIdx; col++) {
                design.set(row, col, getDoubleValue(data, row, col));
            }
            // The label's slot in the design matrix is reused for the intercept column.
            design.set(row, labelIdx, 1.0d);
        }

        double[] distances = cooksDistance(design, response);
        double threshold = IDF.F(THRESHOLD_LEVEL, numCol, numRow - numCol);

        @SuppressWarnings("unchecked") // generic array creation; elements are all Tuple3<Boolean, Double, Map>
        Tuple3<Boolean, Double, Map<String, String>>[] results = new Tuple3[numRow];
        for (int row = 0; row < numRow; row++) {
            Map<String, String> info = new HashMap<>();
            info.put("distance", String.valueOf(distances[row]));
            info.put(BaseStatFeatures.NUMBER, String.valueOf(numRow));
            info.put("p", String.valueOf(numCol));
            info.put("f", String.valueOf(threshold));
            results[row] = Tuple3.of(distances[row] > threshold, distances[row], info);
        }
        return results;
    }

    /**
     * Projects {@code mTable} to the columns used for the regression, with the label in the last
     * column.
     *
     * <ul>
     *   <li>If {@code VECTOR_COL} is set, the vector column is expanded via
     *       {@code OutlierUtil.vectorsToMTable}; {@code LABEL_COL} is not consulted in this mode.</li>
     *   <li>Otherwise, if a non-blank {@code LABEL_COL} is given, it is moved/appended to the end of
     *       {@code FEATURE_COLS} (or of all table columns when no feature columns are set).</li>
     *   <li>Otherwise {@code FEATURE_COLS} (or all table columns) are used as-is, so the last of them
     *       acts as the label.</li>
     * </ul>
     *
     * @param mTable input table
     * @param params detector parameters
     * @return the projected table
     */
    static MTable getMTableWithLabel(MTable mTable, Params params) {
        if (params.contains(WithMultiVarParams.VECTOR_COL)) {
            Tuple2<Vector[], Integer> selected =
                OutlierUtil.selectVectorCol(mTable, (String) params.get(WithMultiVarParams.VECTOR_COL));
            return OutlierUtil.vectorsToMTable(selected.f0, selected.f1);
        }

        String[] columns = (String[]) params.get(WithMultiVarParams.FEATURE_COLS);
        String labelCol = (String) params.get(CooksDistanceDetectorParams.LABEL_COL);
        String[] colNames = mTable.getColNames();

        if (labelCol != null && !labelCol.trim().isEmpty()) {
            if (columns == null) {
                // Every table column except the label, in table order, followed by the label.
                // NOTE(review): if labelCol is not among colNames, the loop fills every slot and
                // overwrites the label with the last column, so the label is silently ignored —
                // confirm whether an explicit error is preferable.
                String[] reordered = new String[colNames.length];
                reordered[reordered.length - 1] = labelCol;
                int next = 0;
                for (String name : colNames) {
                    if (!name.equals(labelCol)) {
                        reordered[next++] = name;
                    }
                }
                columns = reordered;
            } else {
                // Copy rather than mutate the array obtained from params.
                String[] withLabel = new String[columns.length + 1];
                System.arraycopy(columns, 0, withLabel, 0, columns.length);
                withLabel[columns.length] = labelCol;
                columns = withLabel;
            }
        }
        if (columns == null) {
            columns = colNames;
        }
        return OutlierUtil.selectFeatures(mTable, columns);
    }

    /**
     * Computes the Cook's distance of every row of a least-squares fit.
     *
     * <p>With {@code beta = (X^T X)^+ X^T y}, residuals {@code e = y - X beta}, leverages
     * {@code h_i = x_i (X^T X)^+ x_i^T} and {@code s^2 = e^T e / (n - p)}, the distance of row
     * {@code i} is {@code e_i^2 * h_i / (s^2 * (1 - h_i)^2) / p}.
     *
     * @param design   n-by-p design matrix {@code X}
     * @param response n-by-1 response {@code y}
     * @return the n Cook's distances; non-finite when {@code n == p} or a leverage equals 1
     */
    static double[] cooksDistance(DenseMatrix design, DenseMatrix response) {
        int n = design.numRows();
        int p = design.numCols();
        DenseMatrix designT = design.transpose();
        // Pseudo-inverse of the Gram matrix, presumably so a rank-deficient design does not fail.
        DenseMatrix gramPinv = designT.multiplies(design).pseudoInverse();
        DenseMatrix beta = gramPinv.multiplies(designT.multiplies(response));

        // Leverage: i-th diagonal entry of the hat matrix X (X^T X)^+ X^T.
        double[] leverage = new double[n];
        for (int i = 0; i < n; i++) {
            DenseMatrix row = design.getSubMatrix(i, i + 1, 0, p);
            leverage[i] = row.multiplies(gramPinv).multiplies(row.transpose()).get(0, 0);
        }

        DenseMatrix residual = response.minus(design.multiplies(beta));
        // Residual variance estimate with n - p degrees of freedom.
        double s2 = residual.transpose().multiplies(residual).get(0, 0) / (n - p);

        double[] distances = new double[n];
        for (int i = 0; i < n; i++) {
            distances[i] = (((Math.pow(residual.get(i, 0), 2.0d) * leverage[i]) / s2)
                / Math.pow(1.0d - leverage[i], 2.0d)) / p;
        }
        return distances;
    }

    /**
     * Reads a numeric cell as a double.
     *
     * @param mTable source table
     * @param row    row index
     * @param col    column index
     * @return the cell value converted with {@link Number#doubleValue()}
     * @throws AkIllegalOperatorParameterException if the cell is null
     */
    private double getDoubleValue(MTable mTable, int row, int col) {
        Object entry = mTable.getEntry(row, col);
        if (null == entry) {
            throw new AkIllegalOperatorParameterException(
                String.format("the entry of %s row and %s col is null.", row, col));
        }
        return ((Number) entry).doubleValue();
    }
}
