package com.alibaba.alink.operator.common.feature;

import com.alibaba.alink.operator.common.statistics.ChiSquareTestResult;
import com.alibaba.alink.params.feature.BasedChisqSelectorParams;
import com.google.common.primitives.Ints;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/ChisqSelectorUtil.class */
public class ChisqSelectorUtil {

    /**
     * Map-partition function that turns per-feature chi-square test rows into a
     * serialized selector model.
     *
     * <p>Each incoming row is read positionally (fields 0 to 3), the configured
     * selection strategy is applied through {@link ChisqSelectorUtil#selector}, and the
     * resulting {@code ChisqSelectorModelInfo} is written to the collector by
     * {@code ChiSqSelectorModelDataConverter}.
     */
    public static class ChiSquareSelector implements MapPartitionFunction<Row, Row> {
        private static final long serialVersionUID = -482962272562482883L;

        /** Feature column names; may be {@code null}, in which case indices are used as names. */
        private final String[] cols;
        private final BasedChisqSelectorParams.SelectorType selectorType;
        private final int numTopFeatures;
        private final double percentile;
        private final double fpr;
        private final double fdr;
        private final double fwe;

        /**
         * Stores the column names, the selection strategy and the threshold of every strategy.
         * Only the threshold matching {@code selectorType} is consulted during selection, but
         * all of them are copied into the saved model info.
         */
        public ChiSquareSelector(String[] cols, BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, double fdr, double fwe) {
            this.cols = cols;
            this.selectorType = selectorType;
            this.numTopFeatures = numTopFeatures;
            this.percentile = percentile;
            this.fpr = fpr;
            this.fdr = fdr;
            this.fwe = fwe;
        }

        /**
         * Collects every row of the partition into a test result, runs the selection and
         * emits the model rows.
         *
         * @param iterable  rows whose field 0 holds the feature index (as text) and whose
         *                  fields 1 to 3 hold {@code Double} values
         * @param collector sink receiving the serialized model
         */
        public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) {
            List<ChiSquareTestResult> results = new ArrayList<>();
            for (Row row : iterable) {
                // Constructor arguments come from fields 3, 1, 2, 0 in that order
                // (presumably df, p, statistic, column name; confirm against ChiSquareTestResult).
                double fromField3 = ((Double) row.getField(3)).doubleValue();
                double fromField1 = ((Double) row.getField(1)).doubleValue();
                double fromField2 = ((Double) row.getField(2)).doubleValue();
                String fromField0 = row.getField(0).toString();
                results.add(new ChiSquareTestResult(fromField3, fromField1, fromField2, fromField0));
            }

            // The selection may reorder `results`; the array stored in the model below
            // reflects that order.
            int[] selected = ChisqSelectorUtil.selector(
                results, this.selectorType, this.numTopFeatures, this.percentile, this.fpr, this.fdr, this.fwe);

            ChisqSelectorModelInfo modelInfo = new ChisqSelectorModelInfo();
            modelInfo.selectorType = this.selectorType;
            modelInfo.numTopFeatures = this.numTopFeatures;
            modelInfo.percentile = this.percentile;
            modelInfo.fpr = this.fpr;
            modelInfo.fdr = this.fdr;
            modelInfo.fwe = this.fwe;
            modelInfo.colNames = this.cols;
            modelInfo.chiSqs = results.toArray(new ChiSquareTestResult[0]);

            boolean hasNames = this.cols != null;

            // Translate the selected indices into names, or keep the index as text when
            // no column names were supplied.
            String[] selectedNames = new String[selected.length];
            for (int k = 0; k < selected.length; k++) {
                int featureIdx = selected[k];
                if (hasNames) {
                    selectedNames[k] = this.cols[featureIdx];
                } else {
                    selectedNames[k] = String.valueOf(featureIdx);
                }
            }
            modelInfo.siftOutColNames = selectedNames;

            // Replace the numeric index carried in each result with the real column name.
            if (hasNames) {
                for (ChiSquareTestResult result : modelInfo.chiSqs) {
                    result.setColName(this.cols[ChisqSelectorUtil.getIdx(result)]);
                }
            }

            new ChiSqSelectorModelDataConverter().save(modelInfo, collector);
        }
    }

    /**
     * Orders chi-square test results from "best" to "worst".
     *
     * <p>In the default mode results are ordered by {@code getP()} ascending, with ties
     * broken by {@code getValue()} descending. When constructed with {@code true} the
     * {@code getP()} comparison is skipped and only {@code getValue()} descending is used.
     */
    public static class RowAscComparator implements Comparator<ChiSquareTestResult> {
        /** When {@code true}, compare by {@code getValue()} only. */
        private final boolean isChisq;

        /** Creates a comparator that orders by {@code getP()} first. */
        public RowAscComparator() {
            this(false);
        }

        /**
         * @param z {@code true} to order solely by {@code getValue()} descending
         */
        public RowAscComparator(boolean z) {
            this.isChisq = z;
        }

        @Override // java.util.Comparator
        public int compare(ChiSquareTestResult left, ChiSquareTestResult right) {
            if (!this.isChisq) {
                int byP = Double.compare(left.getP(), right.getP());
                if (byP != 0) {
                    return byP;
                }
            }
            // Negated so that the larger value sorts first.
            int byValue = Double.compare(left.getValue(), right.getValue());
            return -byValue;
        }
    }

    /**
     * Chooses which features to keep for the given selection strategy.
     *
     * <ul>
     *   <li>{@code NumTopFeatures}: the first {@code numTopFeatures} results after ordering
     *       with {@link RowAscComparator} (p ascending, ties by statistic descending).</li>
     *   <li>{@code PERCENTILE}: the first {@code (int) (size * percentile)} results in the
     *       same order, but at least one when the list is non-empty.</li>
     *   <li>{@code FPR}: every result whose p-value is strictly below {@code fpr}.</li>
     *   <li>{@code FDR}: results are ordered by p-value; the largest position {@code k}
     *       with {@code p <= fdr * (k + 1) / size} is located and everything up to it is
     *       kept. The top-ranked result is kept even when no position satisfies the bound.
     *       The returned indices are sorted ascending.</li>
     *   <li>{@code FWE}: every result whose p-value is at most {@code fwe / size}.</li>
     * </ul>
     *
     * <p>Each strategy is handled in isolation; no case falls through into another. The
     * input list is sorted in place for {@code NumTopFeatures}, {@code PERCENTILE} and
     * {@code FDR}, so callers observe the new order afterwards.
     *
     * @param list           test results whose column name parses as the feature index
     * @param selectorType   strategy to apply; must not be {@code null}
     * @param numTopFeatures number of features kept by {@code NumTopFeatures}
     * @param percentile     fraction of features kept by {@code PERCENTILE}
     * @param fpr            p-value threshold for {@code FPR}
     * @param fdr            false discovery rate bound for {@code FDR}
     * @param fwe            family-wise error bound for {@code FWE}
     * @return indices of the selected features; empty when nothing qualifies
     */
    public static int[] selector(List<ChiSquareTestResult> list, BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, double fdr, double fwe) {
        int size = list.size();
        List<Integer> selected = new ArrayList<>();
        switch (selectorType) {
            case NumTopFeatures:
                list.sort(new RowAscComparator());
                for (int i = 0; i < numTopFeatures && i < size; i++) {
                    selected.add(getIdx(list.get(i)));
                }
                break;
            case PERCENTILE:
                list.sort(new RowAscComparator());
                int count = (int) (size * percentile);
                if (count == 0) {
                    // Always keep at least one feature.
                    count = 1;
                }
                for (int i = 0; i < count && i < size; i++) {
                    selected.add(getIdx(list.get(i)));
                }
                break;
            case FPR:
                for (ChiSquareTestResult result : list) {
                    if (result.getP() < fpr) {
                        selected.add(getIdx(result));
                    }
                }
                break;
            case FDR:
                // The step-up bound below requires the p-values in ascending order.
                list.sort(new RowAscComparator());
                int maxIdx = 0;
                for (int i = 0; i < size; i++) {
                    if (list.get(i).getP() <= fdr * (i + 1) / size) {
                        maxIdx = i;
                    }
                }
                // The extra size bound keeps an empty input from indexing position 0.
                for (int i = 0; i <= maxIdx && i < size; i++) {
                    selected.add(getIdx(list.get(i)));
                }
                Collections.sort(selected);
                break;
            case FWE:
                for (ChiSquareTestResult result : list) {
                    if (result.getP() <= fwe / size) {
                        selected.add(getIdx(result));
                    }
                }
                break;
            default:
                break;
        }
        return Ints.toArray(selected);
    }

    /**
     * Recovers the feature index stored as text in a result's column name.
     *
     * <p>The name is parsed as a double and rounded to the nearest integer, so both
     * {@code "3"} and {@code "3.0"} yield {@code 3}.
     *
     * @param chiSquareTestResult result whose column name holds a numeric index
     * @return the index as an {@code int}
     * @throws NumberFormatException if the column name is not a parsable number
     */
    static int getIdx(ChiSquareTestResult chiSquareTestResult) {
        String rawName = chiSquareTestResult.getColName();
        long rounded = Math.round(Double.parseDouble(rawName));
        return (int) rounded;
    }
}
