package com.alibaba.alink.operator.common.statistics.basicstatistic;

import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.batch.dataproc.AppendIdBatchOp;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Rank computation used for Spearman correlation: {@link #calcRank} replaces every numeric column value of each
 * input row by its 1-based rank within that column.
 */
public class SpearmanCorrelation {

    /* loaded from: input_file:com/alibaba/alink/operator/common/statistics/basicstatistic/SpearmanCorrelation$CombineRank.class */
    public static class CombineRank extends RichGroupReduceFunction<Tuple3<Integer, Long, Long>, Row> {
        private static final long serialVersionUID = 5398868762882317000L;
        private Boolean outputIsVector;
        private Boolean hasRowId;

        CombineRank(boolean z, boolean z2) {
            this.outputIsVector = Boolean.valueOf(z);
            this.hasRowId = Boolean.valueOf(z2);
        }

        public void reduce(Iterable<Tuple3<Integer, Long, Long>> iterable, Collector<Row> collector) throws Exception {
            Row row;
            int i = -1;
            long j = -1;
            ArrayList arrayList = new ArrayList();
            for (Tuple3<Integer, Long, Long> tuple3 : iterable) {
                if (i < ((Integer) tuple3.f0).intValue()) {
                    i = ((Integer) tuple3.f0).intValue();
                }
                arrayList.add(tuple3);
                j = ((Long) tuple3.f1).longValue();
            }
            int i2 = i + 1;
            double[] dArr = new double[i2];
            Iterator it = arrayList.iterator();
            while (it.hasNext()) {
                dArr[((Integer) ((Tuple3) it.next()).f0).intValue()] = ((Long) r0.f2).longValue();
            }
            if (this.outputIsVector.booleanValue()) {
                row = new Row(1);
                row.setField(0, VectorUtil.toString(new DenseVector(dArr)));
            } else {
                row = new Row(i2);
                for (int i3 = 0; i3 < i2; i3++) {
                    row.setField(i3, Double.valueOf(dArr[i3]));
                }
            }
            if (this.hasRowId.booleanValue()) {
                int arity = row.getArity();
                Row row2 = new Row(arity + 1);
                for (int i4 = 0; i4 < arity; i4++) {
                    row2.setField(i4, row.getField(i4));
                }
                row2.setField(arity, Long.valueOf(j));
                row = row2;
            }
            collector.collect(row);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/statistics/basicstatistic/SpearmanCorrelation$MultiRank.class */
    public static class MultiRank extends RichGroupReduceFunction<Tuple2<Integer, Row>, Tuple3<Integer, Long, Long>> {
        private static final long serialVersionUID = 3455524204690570117L;
        private List<Tuple2<Integer, Long>> counts;
        private long totalCnt = 0;

        MultiRank() {
        }

        public void open(Configuration configuration) throws Exception {
            this.counts = (List) getRuntimeContext().getBroadcastVariableWithInitializer("counts", new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() { // from class: com.alibaba.alink.operator.common.statistics.basicstatistic.SpearmanCorrelation.MultiRank.1
                public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> iterable) {
                    ArrayList arrayList = new ArrayList();
                    Iterator<Tuple2<Integer, Long>> it = iterable.iterator();
                    while (it.hasNext()) {
                        arrayList.add(it.next());
                    }
                    arrayList.sort(Comparator.comparing(tuple2 -> {
                        return (Integer) tuple2.f0;
                    }));
                    return arrayList;
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m558initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Tuple2<Integer, Long>>) iterable);
                }
            });
            this.totalCnt = ((Long) getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt", new BroadcastVariableInitializer<Long, Long>() { // from class: com.alibaba.alink.operator.common.statistics.basicstatistic.SpearmanCorrelation.MultiRank.2
                public Long initializeBroadcastVariable(Iterable<Long> iterable) {
                    return iterable.iterator().next();
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                public /* bridge */ /* synthetic */ Object m559initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Long>) iterable);
                }
            })).longValue();
        }

        private int findCurCount(int i) {
            for (Tuple2<Integer, Long> tuple2 : this.counts) {
                if (((Integer) tuple2.f0).intValue() == i) {
                    return ((Long) tuple2.f1).intValue();
                }
            }
            throw new AkIllegalOperatorParameterException("Error key. key: " + i);
        }

        public void reduce(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple3<Integer, Long, Long>> collector) {
            Row[] rowArr = null;
            int i = -1;
            int i2 = 0;
            for (Tuple2<Integer, Row> tuple2 : iterable) {
                i = ((Integer) tuple2.f0).intValue();
                if (rowArr == null) {
                    rowArr = new Row[findCurCount(i)];
                }
                rowArr[i2] = (Row) tuple2.f1;
                i2++;
            }
            if (rowArr == null) {
                return;
            }
            SortUtils.RowComparator rowComparator = new SortUtils.RowComparator(0);
            SortUtils.ComparableComparator comparableComparator = new SortUtils.ComparableComparator();
            Arrays.sort(rowArr, rowComparator);
            long j = 0;
            for (Tuple2<Integer, Long> tuple22 : this.counts) {
                int intValue = ((Integer) tuple22.f0).intValue();
                if (intValue == i) {
                    break;
                } else {
                    if (intValue > i) {
                        throw new AkIllegalStateException("Error curId: " + intValue + ". id: " + i);
                    }
                    j += ((Long) tuple22.f1).longValue();
                }
            }
            long j2 = j;
            long j3 = 0;
            Object obj = null;
            Integer num = null;
            for (Row row : rowArr) {
                Tuple3 tuple3 = new Tuple3();
                TripleComparable tripleComparable = (TripleComparable) row.getField(0);
                tuple3.f0 = (Integer) tripleComparable.first;
                tuple3.f1 = (Long) tripleComparable.second;
                if (num == null) {
                    num = (Integer) tuple3.f0;
                    j3 = (j2 % this.totalCnt) + 1;
                    obj = tripleComparable.third;
                } else if (num.equals(tuple3.f0)) {
                    Object obj2 = tripleComparable.third;
                    if (1 != comparableComparator.compare(obj, obj2)) {
                        obj = obj2;
                        j3++;
                    }
                } else {
                    j3 = 1;
                    obj = tripleComparable.third;
                    num = (Integer) tuple3.f0;
                }
                tuple3.f2 = Long.valueOf(j3);
                collector.collect(tuple3);
                j2++;
            }
        }
    }

    /**
     * Sort key used for ranking: ordered by {@code first} and, when those are equal, by {@code third}
     * (compared with {@link SortUtils.ComparableComparator}). {@code second} does not take part in the ordering.
     *
     * @param <T0> type of the primary key
     * @param <T1> type of the payload carried along without affecting order
     * @param <T2> type of the secondary key
     */
    public static class TripleComparable<T0 extends Comparable, T1 extends Number, T2 extends Number>
        implements Comparable<TripleComparable> {
        static final SortUtils.ComparableComparator OBJECT_COMPARATOR = new SortUtils.ComparableComparator();
        public T0 first;
        public T1 second;
        public T2 third;

        TripleComparable(T0 first, T1 second, T2 third) {
            this.first = first;
            this.second = second;
            this.third = third;
        }

        @Override
        public int compareTo(TripleComparable other) {
            int byFirst = this.first.compareTo(other.first);
            if (byFirst != 0) {
                return byFirst;
            }
            return OBJECT_COMPARATOR.compare(this.third, other.third);
        }
    }

    /**
     * Replaces every value of each input row by its 1-based rank within its column, without emitting row ids.
     *
     * @param dataSet rows whose fields are all {@link Number} columns to rank
     * @param bool    if {@code true}, each output row holds a single dense-vector string of ranks; otherwise
     *                it holds one {@code Double} field per column
     * @return one row of ranks per input row
     */
    public static DataSet<Row> calcRank(DataSet<Row> dataSet, Boolean bool) {
        final boolean appendRowId = false;
        return calcRank(dataSet, bool, appendRowId);
    }

    /**
     * Replaces every value of each input row by its 1-based rank within its column.
     *
     * <p>Pipeline:
     * <ol>
     *   <li>append a row id with {@link AppendIdBatchOp.AppendIdMapper}; the explode step below reads it as a
     *       {@code Long} from the last field;</li>
     *   <li>count the input rows ({@code totalCnt});</li>
     *   <li>explode each row into one {@link TripleComparable}(columnIndex, rowId, value) per column and
     *       partition-sort the cells on field 0 with {@code SortUtils.pSort};</li>
     *   <li>rank the cells of each sorted group ({@link MultiRank});</li>
     *   <li>regroup by row id and rebuild one rank row per input row ({@link CombineRank}).</li>
     * </ol>
     *
     * @param dataSet input rows; every field must be a {@link Number}
     * @param bool    if {@code true}, emit the ranks of a row as one dense-vector string instead of one
     *                {@code Double} per column
     * @param z       if {@code true}, append the generated row id as a trailing {@code Long} field
     * @return one row of ranks per input row
     */
    static DataSet<Row> calcRank(DataSet<Row> dataSet, Boolean bool, boolean z) {
        DataSet<Row> withId = dataSet.map(new AppendIdBatchOp.AppendIdMapper());

        DataSet<Long> totalCnt = DataSetUtils.countElementsPerPartition(withId)
            .sum(1)
            .map(new MapFunction<Tuple2<Integer, Long>, Long>() {
                private static final long serialVersionUID = -8507632108475760763L;

                @Override
                public Long map(Tuple2<Integer, Long> value) {
                    return value.f1;
                }
            });

        DataSet<Row> cells = withId.flatMap(new FlatMapFunction<Row, Row>() {
            private static final long serialVersionUID = 8019743706688433562L;

            @Override
            public void flatMap(Row row, Collector<Row> out) {
                // The last field is the appended row id; every field before it is a column value.
                int idPos = row.getArity() - 1;
                long rowId = (Long) row.getField(idPos);
                for (int col = 0; col < idPos; col++) {
                    out.collect(Row.of(new TripleComparable<>(col, rowId, (Number) row.getField(col))));
                }
            }
        });

        Tuple2<DataSet<Tuple2<Integer, Row>>, DataSet<Tuple2<Integer, Long>>> sorted = SortUtils.pSort(cells, 0);

        return sorted.f0
            .groupBy(0)
            .withPartitioner(new SortUtils.AvgPartition())
            .reduceGroup(new MultiRank())
            .withBroadcastSet(sorted.f1, "counts")
            .withBroadcastSet(totalCnt, "totalCnt")
            .groupBy(1)
            .withPartitioner(new SortUtils.AvgLongPartitioner())
            .reduceGroup(new CombineRank(bool, z));
    }
}
