package com.alibaba.alink.operator.batch.statistics;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.io.filesystem.AkUtils;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.ParamUtil;
import com.alibaba.alink.params.statistics.RankingListParams;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "objectCol"), @ParamSelectColumnSpec(name = "statCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "addedCols")})
@NameCn("排行榜")
@NameEn("Ranking List")
/* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/RankingListBatchOp.class */
public final class RankingListBatchOp extends BatchOperator<RankingListBatchOp> implements RankingListParams<RankingListBatchOp> {
    private static final Double PRECISION = Double.valueOf(1.0E-16d);
    private static final long serialVersionUID = 2673682618359897321L;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.alibaba.alink.operator.batch.statistics.RankingListBatchOp$1, reason: invalid class name */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/RankingListBatchOp$1.class */
    // Decompiler rendering of the compiler-generated "switch map" holder for RankingListParams.StatType.
    // StatCal.calc(...) indexes this table with statType.ordinal() and switches on the stored number.
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per enum constant (sized by values().length); a slot left at 0 matches no case label
        // in StatCal.calc(...) and therefore ends up in that switch's default branch.
        static final /* synthetic */ int[] $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType = new int[RankingListParams.StatType.values().length];

        static {
            // Each constant is registered in its own try block so that a NoSuchFieldError raised while
            // resolving one constant is swallowed and does not stop the remaining registrations.
            try {
                // count -> case 1
                $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType[RankingListParams.StatType.count.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                // countTotal -> case 2
                $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType[RankingListParams.StatType.countTotal.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                // min -> case 3
                $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType[RankingListParams.StatType.min.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                // max -> case 4
                $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType[RankingListParams.StatType.max.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                // sum -> case 5
                $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType[RankingListParams.StatType.sum.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                // mean -> case 6
                $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType[RankingListParams.StatType.mean.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                // variance -> case 7
                $SwitchMap$com$alibaba$alink$params$statistics$RankingListParams$StatType[RankingListParams.StatType.variance.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/RankingListBatchOp$SortByStatCol.class */
    public static class SortByStatCol implements GroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = -8278621481729657224L;
        private int[] colIdx;
        private int groupColIndex;
        private RankingListParams.StatType[] funcs;
        private TypeInformation<?>[] types;
        private boolean hasAdded;
        private boolean isDesending;
        private int topN;

        /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/RankingListBatchOp$SortByStatCol$StatCal.class */
        /**
         * Running accumulator for one statistic column of one ranked object.
         *
         * <p>Rows are fed in through {@link #add(Row)}; any supported statistic can then be read with
         * {@link #calc(RankingListParams.StatType)}. Numeric moments (sum, sum of squares, min, max) are
         * only maintained when the column type is DOUBLE, LONG or INT; for every other type only the
         * null / non-null counters are updated.
         */
        public class StatCal {
            private final TypeInformation<?> type;
            private final int colIndex;
            /** Number of rows whose value in the column was non-null. */
            private long count = 0;
            /** Number of rows seen, nulls included. */
            private long countTotal = 0;
            // NOTE(review): Criteria.INVALID_GAIN is used here as the neutral start/fallback value; this
            // looks like the decompiler substituting a shared constant for a plain 0.0 literal — confirm.
            private double sum = Criteria.INVALID_GAIN;
            private double sum2 = Criteria.INVALID_GAIN;
            private double min = Double.POSITIVE_INFINITY;
            private double max = Double.NEGATIVE_INFINITY;

            /**
             * @param typeInformation type of the column being aggregated
             * @param i               index of that column inside each incoming row
             */
            public StatCal(TypeInformation<?> typeInformation, int i) {
                this.type = typeInformation;
                this.colIndex = i;
            }

            /**
             * Folds one row into the accumulator. Null values only advance {@code countTotal}.
             *
             * @param row input row; the value is read from the column index given at construction
             */
            public void add(Row row) {
                this.countTotal++;
                boolean numeric = this.type.equals(AlinkTypes.DOUBLE)
                    || this.type.equals(AlinkTypes.LONG)
                    || this.type.equals(AlinkTypes.INT);
                Object value = row.getField(this.colIndex);
                if (value == null) {
                    return;
                }
                if (numeric) {
                    // Parse before touching any counter so a malformed value leaves the state untouched.
                    double parsed = Double.parseDouble(value.toString());
                    this.sum += parsed;
                    this.sum2 += parsed * parsed;
                    // Plain comparisons (not Math.max/min) so a NaN input never replaces the extremes.
                    if (parsed > this.max) {
                        this.max = parsed;
                    }
                    if (parsed < this.min) {
                        this.min = parsed;
                    }
                }
                this.count++;
            }

            /**
             * Returns the requested statistic over all rows added so far.
             *
             * <p>The switch is written directly on the enum rather than through an ordinal-indexed lookup
             * table, so the mapping cannot drift if unrelated integer constants change.
             *
             * @param statType statistic to evaluate
             * @return the statistic as a double; min, max and mean fall back to
             *         {@code Criteria.INVALID_GAIN} when no non-null value was seen, and variance falls
             *         back to it when fewer than two values were seen or all values are equal
             * @throws RuntimeException if {@code statType} is not one of the supported statistics
             */
            public double calc(RankingListParams.StatType statType) {
                switch (statType) {
                    case count:
                        return this.count;
                    case countTotal:
                        return this.countTotal;
                    case min:
                        return this.count == 0 ? Criteria.INVALID_GAIN : this.min;
                    case max:
                        return this.count == 0 ? Criteria.INVALID_GAIN : this.max;
                    case sum:
                        return this.sum;
                    case mean:
                        return this.count == 0 ? Criteria.INVALID_GAIN : this.sum / this.count;
                    case variance:
                        if (0 == this.count || 1 == this.count || this.max == this.min) {
                            return Criteria.INVALID_GAIN;
                        }
                        // Sample variance from the raw moments, clamped from below because rounding
                        // can push the numerator slightly negative.
                        return Math.max(Criteria.INVALID_GAIN, (this.sum2 - ((this.sum / this.count) * this.sum)) / (this.count - 1));
                    default:
                        throw new RuntimeException("statFunc " + statType + " not support.");
                }
            }
        }

        /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/RankingListBatchOp$SortByStatCol$StatComparator.class */
        /**
         * Orders ranking entries by the {@code Double} stored at a fixed position of each entry's row.
         *
         * <p>The comparison is a strict total order based on {@link Double#compare(double, double)}.
         * The previous hand-written version returned a non-zero result for {@code compare(a, b)} but
         * zero for {@code compare(b, a)} when the two values differed by less than the class-level
         * precision, and reported NaN as "greater" in both directions; either case violates the
         * {@link Comparator} contract and can make {@code List.sort} throw or mis-order entries.
         * Values that differ by less than that precision are therefore now ordered instead of tied.
         */
        static class StatComparator implements Comparator<Map.Entry<Object, Row>> {
            private int index;
            private boolean isDesending;

            /**
             * @param i position of the statistic inside each row
             * @param z {@code true} to sort from largest to smallest
             */
            private StatComparator(int i, boolean z) {
                this.index = i;
                this.isDesending = z;
            }

            /**
             * Compares two entries by their statistic value, honoring the configured direction.
             *
             * @throws ClassCastException   if the field at the configured position is not a {@code Double}
             * @throws NullPointerException if that field is null
             */
            @Override // java.util.Comparator
            public int compare(Map.Entry<Object, Row> entry, Map.Entry<Object, Row> entry2) {
                // Adding 0.0 turns -0.0 into +0.0 so the two zeros still compare as equal.
                double left = ((Double) entry.getValue().getField(this.index)).doubleValue() + 0.0d;
                double right = ((Double) entry2.getValue().getField(this.index)).doubleValue() + 0.0d;
                return this.isDesending ? Double.compare(right, left) : Double.compare(left, right);
            }

            // Decompiler-emitted accessor constructor; the third argument is ignored.
            /* synthetic */ StatComparator(int i, boolean z, AnonymousClass1 anonymousClass1) {
                this(i, z);
            }
        }

        /**
         * Creates the ranking reducer.
         *
         * @param iArr               column indices: element 0 is the object column, the following
         *                           elements are the statistic columns, in the same order as
         *                           {@code statTypeArr}
         * @param i                  index of the group column, or -1 when no group column is used
         * @param statTypeArr        statistic to compute for each statistic column
         * @param typeInformationArr column types of the input rows, addressed by column index
         * @param z                  {@code true} to rank from largest to smallest
         * @param i2                 maximum number of ranked rows to emit per reduce call
         */
        public SortByStatCol(int[] iArr, int i, RankingListParams.StatType[] statTypeArr, TypeInformation<?>[] typeInformationArr, boolean z, int i2) {
            this.topN = i2;
            this.isDesending = z;
            this.types = typeInformationArr;
            this.funcs = statTypeArr;
            this.groupColIndex = i;
            this.colIdx = iArr;
            // Anything other than exactly one statistic means extra statistic columns were requested.
            this.hasAdded = !(statTypeArr.length == 1);
        }

        /**
         * Aggregates the incoming rows per object value, sorts the objects by their first statistic and
         * emits at most {@code topN} rows, each extended with its 1-based rank.
         *
         * @param iterable  rows of one reduce group
         * @param collector receives the ranked output rows
         */
        @Override
        public void reduce(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
            final int statCount = this.funcs.length;
            final int objectIdx = this.colIdx[0];

            // Object value -> (group value captured from the first row seen, one accumulator per statistic).
            Map<Object, Tuple2<Object, List<StatCal>>> accumulators = new HashMap<>();
            for (Row row : iterable) {
                Object objectKey = row.getField(objectIdx);
                Object groupValue = this.groupColIndex == -1 ? null : row.getField(this.groupColIndex);
                Tuple2<Object, List<StatCal>> slot = accumulators.get(objectKey);
                if (slot == null) {
                    List<StatCal> fresh = new ArrayList<>();
                    for (int s = 0; s < statCount; s++) {
                        int statColumn = this.colIdx[s + 1];
                        fresh.add(new StatCal(this.types[statColumn], statColumn));
                    }
                    slot = new Tuple2<>(groupValue, fresh);
                    accumulators.put(objectKey, slot);
                }
                for (StatCal accumulator : slot.f1) {
                    accumulator.add(row);
                }
            }

            // Turn every accumulator set into its output row, keyed by the object value.
            Map<Object, Row> summaries = new HashMap<>();
            for (Map.Entry<Object, Tuple2<Object, List<StatCal>>> entry : accumulators.entrySet()) {
                Object[] leading = {entry.getValue().f0, entry.getKey()};
                summaries.put(entry.getKey(), toRow(entry.getValue().f1, leading, this.funcs));
            }

            // The first statistic sits right after the leading columns: (object) or (group, object).
            int statPosition = this.groupColIndex == -1 ? 1 : 2;
            List<Map.Entry<Object, Row>> ranked = new ArrayList<>(summaries.entrySet());
            ranked.sort(new StatComparator(statPosition, this.isDesending));

            long rank = 1;
            for (Map.Entry<Object, Row> entry : ranked) {
                if (rank > this.topN) {
                    break;
                }
                collector.collect(RowUtil.merge(entry.getValue(), Long.valueOf(rank)));
                rank++;
            }
        }

        /**
         * Builds one output row: the leading columns followed by one computed statistic per accumulator.
         *
         * <p>When {@code objArr[0]} (the group value) is null it is omitted and the row starts with
         * {@code objArr[1]} (the object value); otherwise both are written, group value first.
         *
         * @param list        accumulators, one per requested statistic
         * @param objArr      leading values as {@code {groupValue, objectValue}}; needs at least two slots
         * @param statTypeArr statistic to read from each accumulator, parallel to {@code list}
         * @return the assembled row, with every statistic boxed as a {@code Double}
         * @throws RuntimeException if {@code objArr} is null or has fewer than two elements
         */
        private Row toRow(List<StatCal> list, Object[] objArr, RankingListParams.StatType[] statTypeArr) {
            // The body always reads objArr[1], so a one-element array must be rejected here too
            // (the earlier "< 1" guard let it through to an ArrayIndexOutOfBoundsException).
            if (objArr == null || objArr.length < 2) {
                throw new RuntimeException("No Object col info.");
            }
            boolean hasGroupValue = objArr[0] != null;
            int leadingCount = hasGroupValue ? objArr.length : objArr.length - 1;
            int width = leadingCount + list.size();
            Row result = new Row(width);
            if (hasGroupValue) {
                result.setField(0, objArr[0]);
                result.setField(1, objArr[1]);
            } else {
                result.setField(0, objArr[1]);
            }
            for (int s = 0; s < list.size(); s++) {
                result.setField(leadingCount + s, Double.valueOf(list.get(s).calc(statTypeArr[s])));
            }
            return result;
        }
    }

    /** Creates the operator without an initial parameter set (a null {@code Params} is passed on). */
    public RankingListBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameter set forwarded unchanged to the {@code BatchOperator} constructor
     */
    public RankingListBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Validates the ranking parameters against the first input operator, optionally filters the input
     * to the requested group values, and wires a {@link SortByStatCol} reducer whose output becomes this
     * operator's output table.
     *
     * <p>Output columns are: the group column (only when one is configured), the object column, the
     * statistic column (DOUBLE), one DOUBLE column per added column, and a LONG column named "rank".
     *
     * <p>An empty {@code groupCol} is treated exactly like an unset one. Previously an empty string
     * skipped the "values must be set" validation but still entered the group-filter branch, where it
     * either dereferenced missing group values or produced a filter with an empty column name.
     *
     * @param batchOperatorArr upstream operators; only the first one is used
     * @return this operator
     * @throws RuntimeException if a parameter combination is invalid or a column has an unsupported type
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public RankingListBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);
        String groupCol = getGroupCol();
        String[] groupValues = getGroupValues();
        String objectCol = getObjectCol();
        String statCol = getStatCol();
        RankingListParams.StatType statType = getStatType();
        String[] addedCols = getAddedCols();
        String[] addedStatTypes = getAddedStatTypes();
        Boolean isDescending = getIsDescending();
        int topN = ((Integer) super.getParams().get(TOP_N)).intValue();

        // Normalize "" to null so every later check agrees on whether grouping is requested.
        if (groupCol != null && groupCol.isEmpty()) {
            groupCol = null;
        }
        final boolean grouped = groupCol != null;

        // ---- parameter validation ----
        if (input.getColNames().length == 0) {
            throw new RuntimeException("table col num must be larger than 0.");
        }
        if (objectCol == null || objectCol.isEmpty()) {
            throw new RuntimeException("objectCol must be set.");
        }
        boolean statColMissing = statCol == null || statCol.isEmpty();
        if (statColMissing && statType != RankingListParams.StatType.count) {
            throw new RuntimeException("if stat col is null, then statFunc must be count.");
        }
        if (statColMissing) {
            // Counting needs some column to read; fall back to the first one.
            statCol = input.getColNames()[0];
        }
        if ((addedCols == null) != (addedStatTypes == null)) {
            throw new RuntimeException("addedCols and addedStatFuncs length must be same.");
        }
        if (addedCols != null && addedCols.length != addedStatTypes.length) {
            throw new RuntimeException("addedCols and addedStatFuncs length must be same.");
        }
        if (grouped && (groupValues == null || groupValues.length == 0)) {
            throw new RuntimeException("values must be set.");
        }

        // ---- column type validation ----
        TableSchema schema = input.getSchema();
        if (statCol != null && !statCol.isEmpty()) {
            TypeInformation<?> statColType = schema.getFieldType(statCol).get();
            if (!isNumericStatType(statColType) && statType != RankingListParams.StatType.count) {
                throw new RuntimeException("only support count when type not double and long.");
            }
        }
        final int addedCount = addedCols == null ? 0 : addedCols.length;
        for (int k = 0; k < addedCount; k++) {
            TableUtil.assertSelectedColExist(input.getColNames(), addedCols[k]);
            TypeInformation<?> addedType = schema.getFieldType(addedCols[k]).get();
            // NOTE(review): DbscanConstant.COUNT is presumably the literal name of the count statistic
            // (decompiler constant substitution) — confirm against the original source.
            if (!isNumericStatType(addedType) && !addedStatTypes[k].equals(DbscanConstant.COUNT)) {
                throw new RuntimeException("only support count when type not double and long.");
            }
        }
        TypeInformation<?> groupType = null;
        if (grouped) {
            groupType = schema.getFieldType(groupCol).get();
            if (groupType != AlinkTypes.STRING) {
                throw new RuntimeException("group col must be string.");
            }
        }
        TypeInformation<?> objectType = schema.getFieldType(objectCol).get();
        if (objectType != AlinkTypes.STRING && objectType != AlinkTypes.LONG && objectType != AlinkTypes.INT) {
            throw new RuntimeException("objectCol  must be string or bigint.");
        }

        // ---- optional restriction to the requested group values ----
        int groupIdx = -1;
        if (grouped) {
            // NOTE(review): group values are concatenated into the filter expression without escaping
            // embedded single quotes; confirm callers never pass untrusted values here.
            StringBuilder predicate = new StringBuilder();
            for (int k = 0; k < groupValues.length; k++) {
                if (k > 0) {
                    predicate.append(" or ");
                }
                predicate.append(groupCol).append(AkUtils.COLUMN_SPLIT_TAG).append("'").append(groupValues[k]).append("'");
            }
            String filterClause = predicate.toString();
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("filter: " + filterClause);
            }
            input = input.filter(filterClause);
            groupIdx = TableUtil.findColIndexWithAssertAndHint(input.getColNames(), groupCol);
        }

        // ---- reducer configuration: [object, stat, added...] column indices and their statistics ----
        int[] colIdx = new int[addedCount + 2];
        colIdx[0] = TableUtil.findColIndex(input.getColNames(), objectCol);
        colIdx[1] = TableUtil.findColIndex(input.getColNames(), statCol);
        for (int k = 0; k < addedCount; k++) {
            colIdx[k + 2] = TableUtil.findColIndexWithAssertAndHint(input.getColNames(), addedCols[k]);
        }
        RankingListParams.StatType[] statTypes = new RankingListParams.StatType[1 + addedCount];
        statTypes[0] = statType;
        for (int k = 0; k < addedCount; k++) {
            statTypes[1 + k] = (RankingListParams.StatType) ParamUtil.searchEnum(RankingListParams.STAT_TYPE, addedStatTypes[k]);
        }

        // ---- output schema ----
        String[] outNames = grouped
            ? new String[] {groupCol, objectCol, statCol}
            : new String[] {objectCol, statCol};
        TypeInformation<?>[] outTypes = grouped
            ? new TypeInformation<?>[] {groupType, objectType, AlinkTypes.DOUBLE}
            : new TypeInformation<?>[] {objectType, AlinkTypes.DOUBLE};
        if (addedCount > 0) {
            // Every computed statistic is emitted as a double, regardless of the source column type.
            TypeInformation<?>[] addedOutTypes = new TypeInformation<?>[addedCount];
            for (int k = 0; k < addedCount; k++) {
                addedOutTypes[k] = AlinkTypes.DOUBLE;
            }
            outNames = (String[]) ArrayUtils.addAll(outNames, addedCols);
            outTypes = (TypeInformation<?>[]) ArrayUtils.addAll(outTypes, addedOutTypes);
        }
        outNames = (String[]) ArrayUtils.add(outNames, "rank");
        outTypes = (TypeInformation<?>[]) ArrayUtils.add(outTypes, AlinkTypes.LONG);

        SortByStatCol ranker = new SortByStatCol(colIdx, groupIdx, statTypes, input.getColTypes(), isDescending.booleanValue(), topN);
        if (grouped) {
            // One ranking per group value.
            setOutput(input.getDataSet().groupBy(new int[]{groupIdx}).reduceGroup(ranker), outNames, outTypes);
        } else {
            // A single ranking over the whole input.
            setOutput(input.getDataSet().reduceGroup(ranker), outNames, outTypes);
        }
        return this;
    }

    /** Returns whether {@code type} is one of the column types statistics other than count accept. */
    private static boolean isNumericStatType(TypeInformation<?> type) {
        return type == AlinkTypes.INT || type == AlinkTypes.LONG || type == AlinkTypes.DOUBLE;
    }

    // Decompiler rendering of the compiler-generated bridge method: it only casts the array argument
    // and forwards to the typed varargs linkFrom overload declared in this class.
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ RankingListBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
