package com.alibaba.alink.common;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.local.AlinkLocalSession;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.params.shared.HasNumThreads;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/common/MTableUtil.class */
public class MTableUtil implements Serializable {

    /* loaded from: input_file:com/alibaba/alink/common/MTableUtil$FlatMapFunction.class */
    /**
     * Row-level flat-map callback: consumes one input row and may emit any number of output rows
     * through the supplied collector.
     *
     * <p>The interface also extends the raw {@link Function} type, but {@link #apply(Object)} is
     * only a placeholder that always yields {@code null}; {@link #flatMap(Row, Collector)} is the
     * single abstract method. Implementations must be serializable.
     */
    public interface FlatMapFunction extends Function, Serializable {

        /**
         * Processes a single row.
         *
         * @param row       the input row
         * @param collector sink receiving the rows produced for {@code row}
         * @throws Exception if processing of the row fails
         */
        void flatMap(Row row, Collector<Row> collector) throws Exception;

        /**
         * Placeholder implementation of {@link Function#apply(Object)}.
         *
         * @param ignored unused argument
         * @return always {@code null}
         */
        @Override
        default Object apply(Object ignored) {
            return null;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/common/MTableUtil$GenericFlatMapFunction.class */
    /**
     * Serializable flat-map callback over arbitrary element types: consumes one input value and
     * may emit any number of output values through the supplied collector.
     *
     * @param <I> type of the input elements
     * @param <O> type of the emitted elements
     */
    public interface GenericFlatMapFunction<I, O> extends Serializable {

        /**
         * Processes a single input element.
         *
         * @param value     the input element
         * @param collector sink receiving the elements produced for {@code value}
         * @throws Exception if processing of the element fails
         */
        void flatMap(I value, Collector<O> collector) throws Exception;
    }

    /* loaded from: input_file:com/alibaba/alink/common/MTableUtil$GroupFunction.class */
    /**
     * Callback applied to a batch of rows; results are emitted through the supplied collector.
     *
     * <p>The interface also extends the raw {@link Function} type, but {@link #apply(Object)} is
     * only a placeholder that always yields {@code null}; {@link #calc(List, Collector)} is the
     * single abstract method. Implementations must be serializable.
     */
    public interface GroupFunction extends Function, Serializable {

        /**
         * Processes one batch of rows.
         *
         * @param list      the rows to process
         * @param collector sink receiving the rows produced for this batch
         */
        void calc(List<Row> list, Collector<Row> collector);

        /**
         * Placeholder implementation of {@link Function#apply(Object)}.
         *
         * @param ignored unused argument
         * @return always {@code null}
         */
        @Override
        default Object apply(Object ignored) {
            return null;
        }
    }

    /**
     * Extracts the values of a single column of the table.
     *
     * @param mTable the source table
     * @param str    the name of the column to extract
     * @return the field values of that column in row order, or {@code null} when the column
     *         lookup yields index {@code -1}
     */
    public static List<Object> getColumn(MTable mTable, String str) {
        final int colIndex = TableUtil.findColIndex(mTable.getColNames(), str);
        if (colIndex != -1) {
            final List<Row> rows = mTable.getRows();
            final List<Object> values = new ArrayList<>(rows.size());
            for (Row row : rows) {
                values.add(row.getField(colIndex));
            }
            return values;
        }
        // Lookup reported "not found"; callers receive null rather than an empty list.
        return null;
    }

    /**
     * Converts the table into a column-oriented map.
     *
     * @param mTable the source table
     * @return a map from each column name to the list of that column's field values in row order
     */
    public static Map<String, List<Object>> getColumns(MTable mTable) {
        final List<Row> rows = mTable.getRows();
        final String[] colNames = mTable.getColNames();
        final Map<String, List<Object>> columns = new HashMap<>(colNames.length);
        for (String colName : colNames) {
            // NOTE(review): the index is resolved by name for every column instead of using the
            // loop position; kept as is because the matching rules of TableUtil.findColIndex are
            // not visible from here.
            final int colIndex = TableUtil.findColIndex(colNames, colName);
            final List<Object> values = new ArrayList<>(rows.size());
            for (Row row : rows) {
                values.add(row.getField(colIndex));
            }
            columns.put(colName, values);
        }
        return columns;
    }

    /**
     * Coerces an arbitrary value into an {@link MTable}.
     *
     * @param obj an {@code MTable}, a string accepted by {@code MTable.fromJson}, or {@code null}
     * @return {@code obj} itself when it already is an {@code MTable}, the result of
     *         {@code MTable.fromJson} when it is a string, or {@code null} for a {@code null} input
     * @throws AkIllegalDataException if {@code obj} is of any other type
     */
    public static MTable getMTable(Object obj) {
        if (obj instanceof MTable) {
            return (MTable) obj;
        }
        if (obj instanceof String) {
            return MTable.fromJson((String) obj);
        }
        if (obj == null) {
            return null;
        }
        throw new AkIllegalDataException("Type must be string or mtable");
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates a new table whose rows are individual copies (via {@code Row.copy}) of the rows of
     * the given table and whose schema string is taken from the given table.
     *
     * @param mTable the table to duplicate
     * @return a new table holding the copied rows
     */
    public static MTable copy(MTable mTable) {
        final List<Row> source = mTable.getRows();
        final List<Row> duplicated = new ArrayList<>(source.size());
        for (Row row : source) {
            duplicated.add(Row.copy(row));
        }
        return new MTable(duplicated, mTable.getSchemaStr());
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Projects the table onto the named columns.
     *
     * @param mTable the source table
     * @param strArr names of the columns to keep, in output order
     * @return a new table containing only the selected columns
     */
    public static MTable select(MTable mTable, String[] strArr) {
        // Resolve the names to positions first; the lookup helper validates them against the schema.
        final int[] colIndices = TableUtil.findColIndicesWithAssertAndHint(mTable.getSchema(), strArr);
        return select(mTable, strArr, colIndices);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Projects the table onto the columns at the given positions.
     *
     * @param mTable the source table
     * @param iArr   indices of the columns to keep, in output order
     * @return a new table containing only the selected columns
     */
    public static MTable select(MTable mTable, int[] iArr) {
        final String[] allNames = mTable.getColNames();
        final String[] selectedNames = new String[iArr.length];
        int pos = 0;
        for (int colIndex : iArr) {
            selectedNames[pos++] = allNames[colIndex];
        }
        return select(mTable, selectedNames, iArr);
    }

    /**
     * Shared projection routine used by the public {@code select} overloads.
     *
     * @param mTable the source table
     * @param strArr names of the selected columns; used for the output column names and to look
     *               up the output column types
     * @param iArr   indices of the selected columns; used to project every row
     * @return a new table built from the projected rows
     */
    private static MTable select(MTable mTable, String[] strArr, int[] iArr) {
        final List<Row> projected = new ArrayList<>();
        for (Row row : mTable.getRows()) {
            projected.add(Row.project(row, iArr));
        }
        return new MTable(
            projected,
            strArr,
            TableUtil.findColTypesWithAssertAndHint(mTable.getSchema(), strArr));
    }

    /**
     * Groups the rows of a table by the given key columns and applies a function to every group.
     *
     * <p>The rows are placed in a new list (the input table's row list is not reordered), ordered
     * by the key columns, and then split into maximal runs of consecutive rows that compare equal
     * on those columns. Each run is handed to {@code groupFunction} as a sub-list view of the
     * sorted list.
     *
     * @param mTable        the source table
     * @param strArr        names of the grouping key columns
     * @param groupFunction callback invoked once per group
     * @return all rows emitted by {@code groupFunction}, as gathered by the row collector
     */
    public static List<Row> groupFunc(MTable mTable, String[] strArr, GroupFunction groupFunction) {
        final int[] keyIndices = TableUtil.findColIndicesWithAssertAndHint(mTable.getSchema(), strArr);
        final RowTypeInfo rowType = new RowTypeInfo(mTable.getColTypes(), mTable.getColNames());
        // Only equality (compare == 0) is used below, so the all-false order flags do not matter.
        final TypeComparator<Row> keyComparator =
            rowType.createComparator(keyIndices, new boolean[keyIndices.length], 0, new ExecutionConfig());

        // Sort a separate list so the caller's row list keeps its order.
        final List<Row> rowsCopy = new ArrayList<>(mTable.getRows());
        final MTable sorted = new MTable(rowsCopy, mTable.getSchemaStr());
        sorted.orderBy(strArr);

        final RowCollector output = new RowCollector();
        final List<Row> rows = sorted.getRows();
        int groupStart = 0;
        while (groupStart < rows.size()) {
            final Row head = rows.get(groupStart);
            // Advance to the first row whose key differs from the head of the current group.
            int groupEnd = groupStart + 1;
            while (groupEnd < rows.size() && keyComparator.compare(head, rows.get(groupEnd)) == 0) {
                groupEnd++;
            }
            groupFunction.calc(rows.subList(groupStart, groupEnd), output);
            groupStart = groupEnd;
        }
        return output.getRows();
    }

    /**
     * Applies a flat-map function to every row of a table, taking the thread count from the
     * parameters.
     *
     * @param mTable          the table whose rows are processed
     * @param params          parameters; {@code HasNumThreads.NUM_THREADS} is used when present,
     *                        otherwise {@code LocalOperator.getDefaultNumThreads()} applies
     * @param flatMapFunction the function applied to each row
     * @return the rows emitted by the function
     */
    public static List<Row> flatMapWithMultiThreads(MTable mTable, Params params, FlatMapFunction flatMapFunction) {
        final int fallback = LocalOperator.getDefaultNumThreads();
        final int numThreads = params.contains(HasNumThreads.NUM_THREADS)
            ? (Integer) params.get(HasNumThreads.NUM_THREADS)
            : fallback;
        return flatMapWithMultiThreads(mTable, numThreads, flatMapFunction);
    }

    /**
     * Applies a flat-map function to every row of a table using the given number of tasks.
     *
     * @param mTable          the table whose rows are processed
     * @param i               number of partitions/tasks to split the rows into
     * @param flatMapFunction the function applied to each row
     * @return the rows emitted by the function
     */
    public static List<Row> flatMapWithMultiThreads(MTable mTable, int i, FlatMapFunction flatMapFunction) {
        final List<Row> rows = mTable.getRows();
        return flatMapWithMultiThreads(rows, i, flatMapFunction);
    }

    /**
     * Applies a flat-map function to every row of a list, taking the thread count from the
     * parameters.
     *
     * @param list            the rows to process
     * @param params          parameters; {@code HasNumThreads.NUM_THREADS} is used when present,
     *                        otherwise {@code LocalOperator.getDefaultNumThreads()} applies
     * @param flatMapFunction the function applied to each row
     * @return the rows emitted by the function
     */
    public static List<Row> flatMapWithMultiThreads(List<Row> list, Params params, FlatMapFunction flatMapFunction) {
        final int fallback = LocalOperator.getDefaultNumThreads();
        final int numThreads = params.contains(HasNumThreads.NUM_THREADS)
            ? (Integer) params.get(HasNumThreads.NUM_THREADS)
            : fallback;
        return flatMapWithMultiThreads(list, numThreads, flatMapFunction);
    }

    /**
     * Applies a flat-map function to every row of a list, splitting the work into {@code i} tasks.
     *
     * <p>The function is serialized once and each task deserializes its own instance, so tasks
     * never share a function object. Each task handles the contiguous slice of {@code list}
     * assigned to it by {@code AlinkLocalSession.DISTRIBUTOR}; tasks with an empty slice are not
     * submitted. After all tasks have been joined, the per-task results are concatenated in task
     * index order.
     *
     * @param list            the rows to process
     * @param i               number of partitions/tasks; with {@code 0} no task runs and the
     *                        result is empty, a negative value fails when allocating the result
     *                        array
     * @param flatMapFunction the function applied to each row
     * @return the rows emitted by all tasks, ordered by task index
     * @throws AkIllegalDataException raised inside a task when the function fails on a row; the
     *                                original exception is attached as cause
     */
    public static List<Row> flatMapWithMultiThreads(List<Row> list, int i, FlatMapFunction flatMapFunction) {
        final AlinkLocalSession.TaskRunner runner = new AlinkLocalSession.TaskRunner();
        @SuppressWarnings("unchecked")
        final List<Row>[] partials = new List[i];
        final int total = list.size();
        final byte[] functionBytes = SerializationUtils.serialize(flatMapFunction);

        for (int taskId = 0; taskId < i; taskId++) {
            final int begin = (int) AlinkLocalSession.DISTRIBUTOR.startPos(taskId, i, total);
            final int count = (int) AlinkLocalSession.DISTRIBUTOR.localRowCnt(taskId, i, total);
            if (count <= 0) {
                // Nothing assigned to this task; its slot in the result array stays null.
                continue;
            }
            final int slot = taskId;
            runner.submit(() -> {
                // Every task works on a private copy of the function.
                final FlatMapFunction localFunction = (FlatMapFunction) SerializationUtils.deserialize(functionBytes);
                final RowCollector collected = new RowCollector();
                final int end = begin + count;
                for (int pos = begin; pos < end; pos++) {
                    try {
                        localFunction.flatMap(list.get(pos), collected);
                    } catch (Exception e) {
                        throw new AkIllegalDataException("FlatMap error on the data : " + list.get(pos).toString(), e);
                    }
                }
                partials[slot] = collected.getRows();
            });
        }
        runner.join();

        final List<Row> merged = new ArrayList<>();
        for (List<Row> partial : partials) {
            if (partial != null) {
                merged.addAll(partial);
            }
        }
        return merged;
    }

    /**
     * Applies a generic flat-map function to every element of a list, taking the thread count
     * from the parameters.
     *
     * @param list                   the elements to process
     * @param params                 parameters; {@code HasNumThreads.NUM_THREADS} is used when
     *                               present, otherwise
     *                               {@code LocalOperator.getDefaultNumThreads()} applies
     * @param genericFlatMapFunction the function applied to each element
     * @param <I>                    input element type
     * @param <O>                    output element type
     * @return the elements emitted by the function
     */
    public static <I, O> List<O> flatMapWithMultiThreads(List<I> list, Params params, GenericFlatMapFunction<I, O> genericFlatMapFunction) {
        final int fallback = LocalOperator.getDefaultNumThreads();
        final int numThreads = params.contains(HasNumThreads.NUM_THREADS)
            ? (Integer) params.get(HasNumThreads.NUM_THREADS)
            : fallback;
        return flatMapWithMultiThreads(list, numThreads, genericFlatMapFunction);
    }

    /**
     * Applies a generic flat-map function to every element of a list, splitting the work into
     * {@code i} tasks.
     *
     * <p>The function is serialized once and each task deserializes its own instance, so tasks
     * never share a function object. Each task handles the contiguous slice of {@code list}
     * assigned to it by {@code AlinkLocalSession.DISTRIBUTOR}; tasks with an empty slice are not
     * submitted. After all tasks have been joined, the per-task results are concatenated in task
     * index order.
     *
     * @param list                   the elements to process
     * @param i                      number of partitions/tasks; with {@code 0} no task runs and
     *                               the result is empty, a negative value fails when allocating
     *                               the result array
     * @param genericFlatMapFunction the function applied to each element
     * @param <I>                    input element type
     * @param <O>                    output element type
     * @return the elements emitted by all tasks, ordered by task index
     * @throws AkIllegalDataException raised inside a task when the function fails on an element;
     *                                the original exception is attached as cause
     */
    public static <I, O> List<O> flatMapWithMultiThreads(List<I> list, int i, GenericFlatMapFunction<I, O> genericFlatMapFunction) {
        final AlinkLocalSession.TaskRunner taskRunner = new AlinkLocalSession.TaskRunner();
        @SuppressWarnings("unchecked")
        final List<O>[] partials = new List[i];
        final int total = list.size();
        final byte[] functionBytes = SerializationUtils.serialize(genericFlatMapFunction);

        for (int taskId = 0; taskId < i; taskId++) {
            final int begin = (int) AlinkLocalSession.DISTRIBUTOR.startPos(taskId, i, total);
            final int count = (int) AlinkLocalSession.DISTRIBUTOR.localRowCnt(taskId, i, total);
            if (count <= 0) {
                // Nothing assigned to this task; its slot in the result array stays null.
                continue;
            }
            final int slot = taskId;
            taskRunner.submit(() -> {
                // Every task works on a private copy of the function.
                @SuppressWarnings("unchecked")
                final GenericFlatMapFunction<I, O> localFunction =
                    (GenericFlatMapFunction<I, O>) SerializationUtils.deserialize(functionBytes);
                final List<O> collected = new ArrayList<>();
                // One collector per task is enough: it only appends to the task-local list.
                final Collector<O> sink = new Collector<O>() {
                    @Override
                    public void collect(O o) {
                        collected.add(o);
                    }

                    @Override
                    public void close() {
                    }
                };
                final int end = begin + count;
                for (int pos = begin; pos < end; pos++) {
                    try {
                        localFunction.flatMap(list.get(pos), sink);
                    } catch (Exception e) {
                        // Keep the root cause (as the Row overload does) and build the message
                        // null-safely so a null element cannot mask the original failure.
                        throw new AkIllegalDataException("FlatMap error on the data : " + list.get(pos), e);
                    }
                }
                partials[slot] = collected;
            });
        }
        taskRunner.join();

        final List<O> merged = new ArrayList<>();
        for (List<O> partial : partials) {
            if (partial != null) {
                merged.addAll(partial);
            }
        }
        return merged;
    }
}
