package com.alibaba.alink.operator.common.outlier;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.operator.common.similarity.KDTree;
import com.alibaba.alink.params.outlier.HasKDEKernelType;
import com.alibaba.alink.params.outlier.KdeDetectorParams;
import com.alibaba.alink.params.shared.clustering.HasFastDistanceType;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math3.special.Gamma;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/outlier/KdeDetector.class */
/**
 * Outlier detector based on kernel density estimation (KDE).
 *
 * <p>For every sample the detector gathers the distances to its neighbours, evaluates the configured
 * kernel on those distances, and averages the result into a density estimate. The outlier score is the
 * reciprocal of that density; a sample is flagged when its score is strictly greater than the
 * configured threshold.
 */
public class KdeDetector extends OutlierDetector {
    /** Lower clamp used for tree-search distances and for the density before it is inverted. */
    private static final double EPS = 1.0E-18d;
    /** {@code log(2 * pi)}. */
    private static final double LOG_2PI = Math.log(2.0d * Math.PI);
    /** {@code log(pi)}. */
    private static final double LOG_PI = Math.log(Math.PI);

    private final HasFastDistanceType.DistanceType distanceType;
    private final HasKDEKernelType.KernelType kernelType;
    /** Requested neighbour count as configured; never modified after construction. */
    private final int numNeighbors;
    private final double threshold;
    private final double bandwidth;

    /**
     * Reads the detector configuration from {@code params}.
     *
     * @param tableSchema schema passed through to the parent detector
     * @param params      parameters holding neighbour count, threshold, kernel type, bandwidth and
     *                    distance type
     */
    public KdeDetector(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.numNeighbors = ((Integer) params.get(KdeDetectorParams.NUM_NEIGHBORS)).intValue();
        this.threshold = ((Double) params.get(KdeDetectorParams.OUTLIER_THRESHOLD)).doubleValue();
        this.kernelType = (HasKDEKernelType.KernelType) params.get(KdeDetectorParams.KDE_KERNEL_TYPE);
        this.bandwidth = ((Double) params.get(KdeDetectorParams.BANDWIDTH)).doubleValue();
        this.distanceType = (HasFastDistanceType.DistanceType) params.get(KdeDetectorParams.DISTANCE_TYPE);
    }

    /**
     * Scores the rows of {@code mTable} by inverse kernel density.
     *
     * @param mTable table whose rows are converted to vectors via {@code OutlierUtil.getVectors}
     * @param z      when {@code true} only the result for the last row is returned (array of length 1);
     *               otherwise one result per row is returned
     * @return for each reported row: outlier flag, outlier score, and a map holding the density under
     *         the key {@code "KDE"}; an empty array when the table yields no vectors
     * @throws Exception propagated from the vector extraction, distance or tree helpers
     */
    @Override
    @SuppressWarnings("unchecked") // generic arrays of Tuple3 cannot be created without a raw type
    protected Tuple3<Boolean, Double, Map<String, String>>[] detect(MTable mTable, boolean z) throws Exception {
        Vector[] vectors = OutlierUtil.getVectors(mTable, this.params);
        FastDistance fastDistance = this.distanceType.getFastDistance();
        int length = vectors.length;
        if (length == 0) {
            return new Tuple3[0];
        }

        // Resolve the neighbour count for THIS call only. Writing the clamped value back to the field
        // would leak the size of one table into every later call on the same detector.
        int k = (this.numNeighbors < 0 || this.numNeighbors > length) ? length : this.numNeighbors;

        int size = resolveDimension(vectors);

        FastDistanceVectorData[] samples = new FastDistanceVectorData[length];
        for (int i = 0; i < length; i++) {
            samples[i] = fastDistance.prepareVectorData(Row.of(vectors[i], i), 0, 1);
        }

        // A tree search is only needed when fewer than all samples are used as neighbours.
        boolean useTree = k < length;
        KDTree kDTree = null;
        if (useTree) {
            // The tree receives a copy of the array so that `samples` keeps its original order.
            kDTree = new KDTree(samples.clone(), size, fastDistance);
            kDTree.buildTree();
        }

        double[][] distances = new double[length][k];
        for (int i = 0; i < length; i++) {
            if (useTree) {
                int j = 0;
                for (Tuple2<Double, Row> neighbor : kDTree.getTopN(k, samples[i])) {
                    distances[i][j] = Math.max(neighbor.f0.doubleValue(), EPS);
                    j++;
                }
            } else {
                for (int j = 0; j < length; j++) {
                    distances[i][j] = fastDistance.calc(samples[i].getVector(), samples[j].getVector());
                }
            }
        }

        // The normalisation term depends only on bandwidth, dimension and kernel: compute it once.
        double logNorm = logKernelNorm(this.bandwidth, size, this.kernelType);
        double[] densities = new double[length];
        double[] scores = new double[length];
        for (int i = 0; i < length; i++) {
            double logSum = 0.0d;
            for (int j = 0; j < k; j++) {
                double logKernel = logKernelFunction(this.bandwidth, distances[i][j], this.kernelType);
                logSum = (j == 0) ? logKernel : logAddExp(logSum, logKernel);
            }
            // Add in log space before exponentiating: exp(logNorm) * exp(logSum) can evaluate to
            // Infinity * 0 = NaN when the two terms have large opposite magnitudes.
            densities[i] = Math.exp(logNorm + logSum) / k;
            scores[i] = 1.0d / Math.max(densities[i], EPS);
        }

        Tuple3<Boolean, Double, Map<String, String>>[] results;
        if (z) {
            results = new Tuple3[1];
            results[0] = buildResult(densities[length - 1], scores[length - 1]);
        } else {
            results = new Tuple3[length];
            for (int i = 0; i < length; i++) {
                results[i] = buildResult(densities[i], scores[i]);
            }
        }
        return results;
    }

    /**
     * Determines the vector dimension handed to the tree and the kernel normalisation.
     *
     * <p>Uses the reported size of the first vector. When that vector is sparse and reports a negative
     * size (presumably meaning "size not set" - confirm against SparseVector), the dimension is
     * derived as the largest last index over all vectors plus one.
     */
    private static int resolveDimension(Vector[] vectors) {
        int size = vectors[0].size();
        if ((vectors[0] instanceof SparseVector) && size < 0) {
            for (Vector vector : vectors) {
                int[] indices = ((SparseVector) vector).getIndices();
                // A sparse vector without any entry contributes nothing; skip it instead of
                // indexing position -1.
                if (indices != null && indices.length > 0) {
                    size = Math.max(size, indices[indices.length - 1] + 1);
                }
            }
        }
        return size;
    }

    /** Packs density and score of a single sample into the result tuple. */
    private Tuple3<Boolean, Double, Map<String, String>> buildResult(double density, double score) {
        Map<String, String> info = new HashMap<>();
        info.put("KDE", String.valueOf(density));
        return Tuple3.of(Boolean.valueOf(score > this.threshold), Double.valueOf(score), info);
    }

    /** Returns {@code 0.5 * i * log(pi) - logGamma(0.5 * i + 1)}. */
    private static double logVn(int i) {
        return ((0.5d * i) * LOG_PI) - Gamma.logGamma((0.5d * i) + 1.0d);
    }

    /** Returns {@code log(2 * pi) + logVn(i - 1)}. */
    private static double logSn(int i) {
        return LOG_2PI + logVn(i - 1);
    }

    /**
     * Computes {@code log(exp(d) + exp(d2))} without leaving log space.
     *
     * <p>The larger argument is factored out so that the remaining exponential is at most 1, which
     * avoids the underflow/overflow of exponentiating both arguments directly.
     */
    private static double logAddExp(double d, double d2) {
        double hi = Math.max(d, d2);
        if (Double.isInfinite(hi)) {
            // Both arguments are -Infinity (result -Infinity) or one is +Infinity (result +Infinity);
            // subtracting would produce NaN.
            return hi;
        }
        return hi + Math.log1p(Math.exp(Math.min(d, d2) - hi));
    }

    /**
     * Logarithm of the kernel normalisation factor.
     *
     * @param d          bandwidth
     * @param i          vector dimension
     * @param kernelType kernel to normalise
     * @return negated kernel-specific log factor minus {@code i * log(d)}
     * @throws IllegalArgumentException if the kernel type is not handled
     */
    private double logKernelNorm(double d, int i, HasKDEKernelType.KernelType kernelType) {
        double logFactor;
        switch (kernelType) {
            case GAUSSIAN:
                logFactor = 0.5d * i * LOG_2PI;
                break;
            case LINEAR:
                logFactor = logVn(i) - Math.log(i + 1.0d);
                break;
            default:
                throw new IllegalArgumentException("KDE Kernel not recognized");
        }
        return (-logFactor) - (i * Math.log(d));
    }

    /**
     * Logarithm of the un-normalised kernel value at a given distance.
     *
     * @param d          bandwidth
     * @param d2         distance between two samples
     * @param kernelType kernel to evaluate
     * @return log kernel value; {@code -Infinity} for the linear kernel when the distance is not
     *         smaller than the bandwidth
     * @throws IllegalArgumentException if the kernel type is not handled
     */
    private double logKernelFunction(double d, double d2, HasKDEKernelType.KernelType kernelType) {
        switch (kernelType) {
            case GAUSSIAN:
                return ((-0.5d) * (d2 * d2)) / (d * d);
            case LINEAR:
                if (d2 < d) {
                    return Math.log(1.0d - (d2 / d));
                }
                return Double.NEGATIVE_INFINITY;
            default:
                throw new IllegalArgumentException("KDE Kernel not recognized");
        }
    }
}
