package com.alibaba.alink.operator.common.statistics.basicstatistic;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import java.util.Map;

/* loaded from: input_file:com/alibaba/alink/operator/common/statistics/basicstatistic/VectorSummarizerUtil.class */
/**
 * Static helpers that combine two partially accumulated vector summarizers (dense or sparse) into
 * a single summarizer.
 */
public class VectorSummarizerUtil {
    /**
     * Merges two summarizers, dispatching on their concrete dense/sparse types.
     *
     * <p>Arguments may be mutated. For dense + dense, the first argument is updated in place,
     * unless it is narrower than the second; in that case a copy of the second is returned with
     * the first merged into it. For any combination involving a sparse summarizer, the sparse
     * argument (the first one if both are sparse) is updated in place and returned. Callers must
     * always use the returned value.
     *
     * @param baseVectorSummarizer the first summarizer
     * @param baseVectorSummarizer2 the second summarizer
     * @return the merged summarizer; a {@code SparseVectorSummarizer} whenever either input is sparse
     * @throws AkUnsupportedOperationException if either argument is null or is neither a
     *     {@code SparseVectorSummarizer} nor a {@code DenseVectorSummarizer}
     */
    public static BaseVectorSummarizer merge(BaseVectorSummarizer baseVectorSummarizer, BaseVectorSummarizer baseVectorSummarizer2) {
        if ((baseVectorSummarizer instanceof SparseVectorSummarizer) && (baseVectorSummarizer2 instanceof SparseVectorSummarizer)) {
            return merge((SparseVectorSummarizer) baseVectorSummarizer, (SparseVectorSummarizer) baseVectorSummarizer2);
        }
        if ((baseVectorSummarizer instanceof SparseVectorSummarizer) && (baseVectorSummarizer2 instanceof DenseVectorSummarizer)) {
            return merge((SparseVectorSummarizer) baseVectorSummarizer, (DenseVectorSummarizer) baseVectorSummarizer2);
        }
        if ((baseVectorSummarizer instanceof DenseVectorSummarizer) && (baseVectorSummarizer2 instanceof SparseVectorSummarizer)) {
            return merge((DenseVectorSummarizer) baseVectorSummarizer, (SparseVectorSummarizer) baseVectorSummarizer2);
        }
        if ((baseVectorSummarizer instanceof DenseVectorSummarizer) && (baseVectorSummarizer2 instanceof DenseVectorSummarizer)) {
            return merge((DenseVectorSummarizer) baseVectorSummarizer, (DenseVectorSummarizer) baseVectorSummarizer2);
        }
        // Name the offending runtime types so the failure is diagnosable.
        throw new AkUnsupportedOperationException(
            "Unsupported summarizer types for merge: " + typeName(baseVectorSummarizer)
                + " and " + typeName(baseVectorSummarizer2)
                + ". Only SparseVectorSummarizer and DenseVectorSummarizer are supported.");
    }

    /** Returns the runtime class name of {@code value}, or {@code "null"} for a null reference. */
    private static String typeName(Object value) {
        return value == null ? "null" : value.getClass().getName();
    }

    /**
     * Merges two dense summarizers.
     *
     * <p>{@code right} is only read. If {@code right} is empty, {@code left} is returned unchanged.
     * If {@code left} is empty, it adopts clones of {@code right}'s statistics. If {@code left} is
     * narrower than {@code right}, a copy of {@code right} is returned with {@code left} merged in.
     * Otherwise {@code right} is accumulated into {@code left} column by column.
     */
    private static DenseVectorSummarizer merge(DenseVectorSummarizer left, DenseVectorSummarizer right) {
        if (right.count == 0) {
            return left;
        }
        if (left.count == 0) {
            // Adopt right's statistics wholesale; clone so left never aliases right's storage.
            left.count = right.count;
            left.sum = right.sum.mo136clone();
            left.squareSum = right.squareSum.mo136clone();
            left.normL1 = right.normL1.mo136clone();
            left.min = right.min.mo136clone();
            left.max = right.max.mo136clone();
            left.numNonZero = right.numNonZero.mo136clone();
            if (right.outerProduct != null) {
                left.outerProduct = right.outerProduct.m134clone();
            }
            return left;
        }
        int leftSize = left.sum.size();
        int rightSize = right.sum.size();
        if (leftSize < rightSize) {
            // Always accumulate into the wider summarizer; copy it so right stays untouched.
            return merge(right.copy(), left);
        }
        left.count += right.count;
        // Columns at index >= rightSize keep left's existing statistics.
        for (int i = 0; i < rightSize; i++) {
            left.sum.add(i, right.sum.get(i));
            left.squareSum.add(i, right.squareSum.get(i));
            left.normL1.add(i, right.normL1.get(i));
            left.min.set(i, Math.min(left.min.get(i), right.min.get(i)));
            left.max.set(i, Math.max(left.max.get(i), right.max.get(i)));
            left.numNonZero.add(i, right.numNonZero.get(i));
        }
        if (left.outerProduct != null && right.outerProduct != null) {
            // Only the leading rightSize x rightSize corner receives contributions from right.
            for (int row = 0; row < rightSize; row++) {
                for (int col = 0; col < rightSize; col++) {
                    left.outerProduct.add(row, col, right.outerProduct.get(row, col));
                }
            }
        } else if (left.outerProduct == null && right.outerProduct != null) {
            left.outerProduct = right.outerProduct.m134clone();
        }
        return left;
    }

    /** Merges a dense summarizer into a sparse one; the sparse summarizer is updated and returned. */
    private static SparseVectorSummarizer merge(DenseVectorSummarizer dense, SparseVectorSummarizer sparse) {
        return merge(sparse, dense);
    }

    /**
     * Merges the dense summarizer {@code right} into the sparse summarizer {@code left}, updating
     * and returning {@code left}. Each dense column is converted to a fresh {@link VectorStatCol}.
     * The dense summarizer is not modified.
     */
    private static SparseVectorSummarizer merge(SparseVectorSummarizer left, DenseVectorSummarizer right) {
        if (right.count != 0) {
            left.count += right.count;
            int rightSize = right.sum.size();
            // Keep colNum in line with the sparse + sparse merge, which tracks the widest input.
            left.colNum = Math.max(left.colNum, rightSize);
            for (int i = 0; i < rightSize; i++) {
                VectorStatCol col = new VectorStatCol();
                col.numNonZero = (long) right.numNonZero.get(i);
                col.sum = right.sum.get(i);
                col.squareSum = right.squareSum.get(i);
                col.min = right.min.get(i);
                col.max = right.max.get(i);
                col.normL1 = right.normL1.get(i);
                VectorStatCol existing = left.cols.get(i);
                if (existing != null) {
                    existing.merge(col);
                } else {
                    left.cols.put(i, col);
                }
            }
        }
        left.outerProduct = mergeOuterProduct(left.outerProduct, right.outerProduct);
        return left;
    }

    /**
     * Merges the sparse summarizer {@code right} into {@code left}, updating and returning
     * {@code left}. Columns present only in {@code right} are copied, not shared, so later merges
     * into {@code left} cannot alter {@code right}'s state.
     */
    private static SparseVectorSummarizer merge(SparseVectorSummarizer left, SparseVectorSummarizer right) {
        left.count += right.count;
        left.colNum = Math.max(right.colNum, left.colNum);
        for (Map.Entry<Integer, VectorStatCol> entry : right.cols.entrySet()) {
            Integer index = entry.getKey();
            VectorStatCol existing = left.cols.get(index);
            if (existing != null) {
                existing.merge(entry.getValue());
            } else {
                left.cols.put(index, copyOf(entry.getValue()));
            }
        }
        left.outerProduct = mergeOuterProduct(left.outerProduct, right.outerProduct);
        return left;
    }

    /**
     * Returns a new {@link VectorStatCol} holding the same statistics as {@code source}.
     * NOTE(review): copies the six fields this class populates; update if VectorStatCol gains more.
     */
    private static VectorStatCol copyOf(VectorStatCol source) {
        VectorStatCol copy = new VectorStatCol();
        copy.numNonZero = source.numNonZero;
        copy.sum = source.sum;
        copy.squareSum = source.squareSum;
        copy.min = source.min;
        copy.max = source.max;
        copy.normL1 = source.normL1;
        return copy;
    }

    /**
     * Combines two outer-product matrices for the sparse merge paths and returns the result.
     *
     * <p>If {@code right} is null, {@code left} is returned as-is. If {@code left} is null, a clone
     * of {@code right} is returned. Otherwise, if {@code right} has more rows, {@code left} is
     * first copied into a zero square matrix of {@code right}'s row count. {@code right} is then
     * added element-wise.
     */
    private static DenseMatrix mergeOuterProduct(DenseMatrix left, DenseMatrix right) {
        if (right == null) {
            return left;
        }
        if (left == null) {
            return right.m134clone();
        }
        int numRows = right.numRows();
        if (numRows > left.numRows()) {
            left = plusEqual(DenseMatrix.zeros(numRows, numRows), left);
        }
        return plusEqual(left, right);
    }

    /**
     * Adds every element of {@code denseMatrix2} into the same position of {@code denseMatrix},
     * in place, and returns {@code denseMatrix}. {@code denseMatrix} must be at least as large as
     * {@code denseMatrix2} in both dimensions.
     * (The decompiler notes this method was originally package-private.)
     */
    public static DenseMatrix plusEqual(DenseMatrix denseMatrix, DenseMatrix denseMatrix2) {
        for (int i = 0; i < denseMatrix2.numRows(); i++) {
            for (int i2 = 0; i2 < denseMatrix2.numCols(); i2++) {
                denseMatrix.add(i, i2, denseMatrix2.get(i, i2));
            }
        }
        return denseMatrix;
    }
}
