package com.alibaba.alink.operator.local.feature;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.feature.ExclusiveFeatureBundleModelDataConverter;
import com.alibaba.alink.operator.common.feature.FeatureBundles;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.params.feature.ExclusiveFeatureBundlePredictParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/**
 * Local operator that trains an exclusive-feature-bundle model.
 *
 * <p>{@link #linkFrom(LocalOperator[])} reads one sparse-vector column from the first input,
 * groups the vector positions into bundles with {@link #extract(SparseVector[])} and writes the
 * serialized model to this operator's output table.
 */
@NameCn("互斥特征捆绑模型训练")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.ExclusiveFeatureBundle")
public class ExclusiveFeatureBundleTrainLocalOp extends LocalOperator<ExclusiveFeatureBundleTrainLocalOp> implements ExclusiveFeatureBundlePredictParams<ExclusiveFeatureBundleTrainLocalOp> {
    /** Creates the operator with an empty parameter set. */
    public ExclusiveFeatureBundleTrainLocalOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters, passed through to the superclass
     */
    public ExclusiveFeatureBundleTrainLocalOp(Params params) {
        super(params);
    }

    /**
     * Trains the model from the first input operator.
     *
     * <p>Every entry of the column named by {@code getSparseVectorCol()} is converted to a
     * {@link SparseVector}; the bundles computed by {@link #extract(SparseVector[])} are saved
     * through an {@link ExclusiveFeatureBundleModelDataConverter} into the output table.
     *
     * @param localOperatorArr input operators; only the first one is used
     * @return this operator, with its output table set to the model rows
     */
    @Override
    public ExclusiveFeatureBundleTrainLocalOp linkFrom(LocalOperator<?>... localOperatorArr) {
        LocalOperator<?> input = checkAndGetFirst(localOperatorArr);
        int vectorColIndex = TableUtil.findColIndexWithAssertAndHint(input.getSchema(), getSparseVectorCol());
        MTable inputTable = input.getOutputTable();
        int numRow = inputTable.getNumRow();
        SparseVector[] vectors = new SparseVector[numRow];
        for (int row = 0; row < numRow; row++) {
            vectors[row] = VectorUtil.getSparseVector(inputTable.getEntry(row, vectorColIndex));
        }

        FeatureBundles bundles = extract(vectors);

        ExclusiveFeatureBundleModelDataConverter converter = new ExclusiveFeatureBundleModelDataConverter();
        TableSchema efbSchema = TableUtil.schemaStr2Schema(bundles.getSchemaStr());
        converter.efbColNames = efbSchema.getFieldNames();
        converter.efbColTypes = efbSchema.getFieldTypes();

        RowCollector rowCollector = new RowCollector();
        converter.save(bundles, rowCollector);
        setOutputTable(new MTable(rowCollector.getRows(), converter.getModelSchema()));
        return this;
    }

    /**
     * Groups the positions of the given sparse vectors into bundles.
     *
     * <p>Positions that ever hold a value other than {@code 1.0} are collected first. The fast
     * path {@link #quickCheck} is tried; if it declines (returns {@code null}) a greedy procedure
     * is used: positions present in every vector (or holding a non-1.0 value) become singleton
     * bundles, and the remaining positions are grouped by repeated calls to
     * {@link #findNewBundle} until every used position has been assigned.
     *
     * @param sparseVectorArr training vectors; must contain at least one element. The dimension
     *                        is taken from the first vector's {@code size()}
     * @return the bundles wrapped together with the vector dimension
     * @throws IllegalArgumentException if {@code sparseVectorArr} is empty
     */
    public static FeatureBundles extract(SparseVector[] sparseVectorArr) {
        if (sparseVectorArr.length == 0) {
            throw new IllegalArgumentException("sparseVectorArr must contain at least one vector");
        }

        // Positions whose stored value is ever different from 1.0.
        Set<Integer> nonBinaryIndices = new TreeSet<>();
        for (SparseVector vector : sparseVectorArr) {
            int[] indices = vector.getIndices();
            double[] values = vector.getValues();
            for (int k = 0; k < values.length; k++) {
                if (values[k] != 1.0d) {
                    nonBinaryIndices.add(indices[k]);
                }
            }
        }

        int size = sparseVectorArr[0].size();
        List<int[]> bundles = quickCheck(sparseVectorArr, nonBinaryIndices, size);
        if (null == bundles) {
            int numVectors = sparseVectorArr.length;
            ArrayList<int[]> rowIndices = new ArrayList<>(numVectors);
            for (SparseVector vector : sparseVectorArr) {
                rowIndices.add(vector.getIndices());
            }

            // usage[f] = number of vectors that store position f.
            int[] usage = new int[size];
            for (int[] indices : rowIndices) {
                for (int index : indices) {
                    usage[index]++;
                }
            }
            // Force non-binary positions to count as "present everywhere" so that they
            // end up in singleton bundles below.
            for (int index : nonBinaryIndices) {
                usage[index] = numVectors;
            }

            bundles = new ArrayList<>();
            for (int f = 0; f < size; f++) {
                if (usage[f] == numVectors) {
                    bundles.add(new int[]{f});
                    usage[f] = 0; // assigned; no longer a candidate
                }
            }

            // Keep carving out bundles until no position with a positive count is left.
            while (sumNumUsed(usage) > 0) {
                // findNewBundle mutates the count array it is given, so hand it a copy.
                List<Integer> picked = findNewBundle(rowIndices, Arrays.copyOf(usage, size), size);
                int[] bundle = new int[picked.size()];
                for (int k = 0; k < bundle.length; k++) {
                    bundle[k] = picked.get(k);
                }
                bundles.add(bundle);
                for (int f : bundle) {
                    usage[f] = 0;
                }
            }
        }
        return new FeatureBundles(size, bundles);
    }

    /**
     * Fast path: tries to split the positions into contiguous ranges, one per stored slot.
     *
     * <p>Succeeds only if all vectors store the same number of values, the index ranges
     * {@code [min, max]} observed at consecutive slots do not overlap, and no index in
     * {@code set} falls inside a slot whose index actually varies. On success, bundle {@code k}
     * covers {@code [min[k], min[k+1])}, and the last bundle covers {@code [min[last], i)}.
     *
     * @param sparseVectorArr training vectors (at least one)
     * @param set             positions that hold a value other than 1.0 in some vector
     * @param i               vector dimension, used as the exclusive end of the last bundle
     * @return the contiguous bundles, or {@code null} if the fast path does not apply
     */
    static List<int[]> quickCheck(SparseVector[] sparseVectorArr, Set<Integer> set, int i) {
        int numberOfValues = sparseVectorArr[0].numberOfValues();
        for (SparseVector vector : sparseVectorArr) {
            if (numberOfValues != vector.numberOfValues()) {
                return null;
            }
        }
        if (numberOfValues == 0) {
            // No stored values at all: there is no "last slot" to build a bundle from
            // (indexing slot -1 would throw), so defer to the general procedure.
            return null;
        }

        // Per-slot minimum and maximum index over all vectors.
        int[] minIndex = Arrays.copyOf(sparseVectorArr[0].getIndices(), numberOfValues);
        int[] maxIndex = Arrays.copyOf(sparseVectorArr[0].getIndices(), numberOfValues);
        for (SparseVector vector : sparseVectorArr) {
            int[] indices = vector.getIndices();
            for (int slot = 0; slot < numberOfValues; slot++) {
                minIndex[slot] = Math.min(minIndex[slot], indices[slot]);
                maxIndex[slot] = Math.max(maxIndex[slot], indices[slot]);
            }
        }

        // Ranges of neighbouring slots must be strictly separated.
        for (int slot = 0; slot < numberOfValues - 1; slot++) {
            if (maxIndex[slot] >= minIndex[slot + 1]) {
                return null;
            }
        }

        List<int[]> bundles = new ArrayList<>(numberOfValues);
        for (int slot = 0; slot < numberOfValues; slot++) {
            int start = minIndex[slot];
            int end = (slot + 1 < numberOfValues) ? minIndex[slot + 1] : i;
            int[] bundle = new int[end - start];
            for (int k = 0; k < bundle.length; k++) {
                bundle[k] = start + k;
            }
            bundles.add(bundle);
        }

        // A non-binary position inside a slot whose index varies disqualifies the fast path.
        for (int index : set) {
            for (int slot = 0; slot < numberOfValues; slot++) {
                if (minIndex[slot] <= index && index <= maxIndex[slot] && maxIndex[slot] > minIndex[slot]) {
                    return null;
                }
            }
        }
        return bundles;
    }

    /**
     * Greedily builds one bundle of positions.
     *
     * <p>The position with the highest count is picked as seed. Every row containing the seed is
     * dropped and all positions of those rows are zeroed in {@code iArr}. The rest of the bundle
     * is then chosen from the positions that still have a positive count: all of them if their
     * counts sum to the number of remaining rows; the first ones (at most one per remaining row)
     * if every count is 1; otherwise by recursing on the remaining rows.
     *
     * @param arrayList index arrays of the rows still to be considered
     * @param iArr      per-position counts; <b>modified in place</b>
     * @param i         number of positions (length of {@code iArr} that is scanned)
     * @return the positions forming the new bundle, seed first
     */
    static List<Integer> findNewBundle(ArrayList<int[]> arrayList, int[] iArr, int i) {
        List<Integer> bundle = new ArrayList<>();
        int seed = findMaxDegree(iArr);
        bundle.add(seed);

        ArrayList<int[]> remainingRows = new ArrayList<>();
        for (int[] row : arrayList) {
            if (contains(row, seed)) {
                // Everything co-occurring with the seed conflicts with it.
                for (int index : row) {
                    iArr[index] = 0;
                }
            } else {
                remainingRows.add(row);
            }
        }

        int totalCount = 0;
        int maxCount = 0;
        for (int f = 0; f < i; f++) {
            if (iArr[f] > 0) {
                totalCount += iArr[f];
                maxCount = Math.max(maxCount, iArr[f]);
            }
        }

        if (remainingRows.size() == totalCount) {
            for (int f = 0; f < i; f++) {
                if (iArr[f] > 0) {
                    bundle.add(f);
                }
            }
        } else if (1 == maxCount) {
            int added = 0;
            for (int f = 0; f < i; f++) {
                if (iArr[f] > 0 && added < remainingRows.size()) {
                    bundle.add(f);
                    added++;
                }
            }
        } else if (maxCount > 1) {
            bundle.addAll(findNewBundle(remainingRows, iArr, i));
        }
        // maxCount == 0 with rows still remaining: no candidate position is left, so the
        // bundle is complete. Recursing here would repeat the same state forever.
        return bundle;
    }

    /**
     * Sums all entries of the array.
     *
     * @param iArr values to add up
     * @return the {@code int} sum of all entries
     */
    static int sumNumUsed(int[] iArr) {
        int sum = 0;
        for (int value : iArr) {
            sum += value;
        }
        return sum;
    }

    /**
     * Finds the position of the largest strictly positive entry.
     *
     * @param iArr per-position counts
     * @return index of the first maximal entry greater than zero, or {@code 0} if there is none
     */
    static int findMaxDegree(int[] iArr) {
        int bestIndex = 0;
        int bestValue = 0;
        for (int k = 0; k < iArr.length; k++) {
            if (iArr[k] > bestValue) {
                bestIndex = k;
                bestValue = iArr[k];
            }
        }
        return bestIndex;
    }

    /**
     * Tests whether the whole array contains the given value.
     *
     * @param iArr array to search; see {@link #contains(int[], int, int, int)} for the ordering
     *             assumption
     * @param i    value to look for
     * @return {@code true} if the value is found
     */
    static boolean contains(int[] iArr, int i) {
        return contains(iArr, 0, iArr.length, i);
    }

    /**
     * Tests whether {@code iArr[i .. i2)} contains {@code i3}.
     *
     * <p>Ranges of 10 or more elements are narrowed by bisection, the rest is scanned linearly.
     * The bisection is only correct when the range is sorted ascending; callers in this class
     * pass sparse-vector index arrays (presumably sorted -- confirm against SparseVector).
     *
     * @param iArr array to search
     * @param i    inclusive start of the range
     * @param i2   exclusive end of the range
     * @param i3   value to look for
     * @return {@code true} if the value is found
     */
    static boolean contains(int[] iArr, int i, int i2, int i3) {
        int lo = i;
        int hi = i2;
        while (hi - lo >= 10) {
            int mid = lo + (hi - lo) / 2;
            if (iArr[mid] == i3) {
                return true;
            }
            if (iArr[mid] > i3) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        for (int k = lo; k < hi; k++) {
            if (iArr[k] == i3) {
                return true;
            }
        }
        return false;
    }
}
