package com.alibaba.alink.operator.batch.utils;

import com.alibaba.alink.common.io.filesystem.AkUtils2;
import com.alibaba.alink.common.io.filesystem.BaseFileSystem;
import com.alibaba.alink.common.io.filesystem.FilePath;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.io.dummy.DummyOutputFormat;
import com.alibaba.alink.operator.common.io.partition.SinkCollectorCreator;
import com.alibaba.alink.operator.common.io.partition.SourceCollectorCreator;
import com.alibaba.alink.params.io.HasFilePath;
import com.alibaba.alink.params.io.shared.HasPartitionColsDefaultAsNull;
import com.alibaba.alink.params.io.shared.HasPartitions;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang.ArrayUtils;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.core.fs.Path;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/batch/utils/DataSetUtil.class */
/**
 * Static helpers for common Flink {@link DataSet} manipulations used by Alink batch operators:
 * counting, barriers, empty data sets, row reshaping, batching rows into lists, and reading /
 * writing partitioned directory layouts.
 */
public class DataSetUtil {

    /**
     * Counts the elements of a data set.
     *
     * <p>Every parallel partition counts its own elements. The partial counts are then summed
     * with a reduce.
     *
     * @param dataSet the data set to count
     * @param <T>     element type
     * @return a data set holding the total element count
     */
    public static <T> DataSet<Long> count(DataSet<T> dataSet) {
        return dataSet
            .mapPartition(new MapPartitionFunction<T, Long>() {
                private static final long serialVersionUID = 5351290184340971835L;

                @Override
                public void mapPartition(Iterable<T> values, Collector<Long> out) {
                    long partitionCount = 0L;
                    for (T ignored : values) {
                        partitionCount++;
                    }
                    // The partial count is emitted even when it is 0.
                    out.collect(partitionCount);
                }
            })
            .name("count_dataset")
            .returns(Types.LONG)
            .reduce(new ReduceFunction<Long>() {
                private static final long serialVersionUID = -4281590383844098422L;

                @Override
                public Long reduce(Long left, Long right) {
                    return left + right;
                }
            });
    }

    /**
     * Counts the rows produced by a batch operator.
     *
     * @param batchOperator the operator whose output rows are counted
     * @return an operator with a single {@code LONG} column named {@code num_records}
     */
    public static BatchOperator count(BatchOperator batchOperator) {
        // The parameter is a raw type, so getDataSet() is erased; bind it to the row type here.
        @SuppressWarnings("unchecked")
        DataSet<Row> rows = batchOperator.getDataSet();

        DataSet<Row> countRow = count(rows)
            .map(new MapFunction<Long, Row>() {
                private static final long serialVersionUID = -5014428103964711477L;

                @Override
                public Row map(Long total) {
                    return Row.of(total);
                }
            });

        return BatchOperator.fromTable(
            DataSetConversionUtil.toTable(
                batchOperator.getMLEnvironmentId(),
                countRow,
                new String[] {"num_records"},
                new TypeInformation<?>[] {Types.LONG}));
    }

    /**
     * Returns a data set that emits no elements but keeps the element type of {@code dataSet}.
     *
     * @param dataSet the upstream data set
     * @param <T>     element type
     * @return an empty data set that depends on {@code dataSet}
     */
    public static <T> DataSet<T> empty(DataSet<T> dataSet) {
        return dataSet
            .flatMap(new FlatMapFunction<T, T>() {
                private static final long serialVersionUID = 4385675521544606204L;

                @Override
                public void flatMap(T value, Collector<T> out) {
                    // Intentionally emits nothing.
                }
            })
            .returns(dataSet.getType());
    }

    /**
     * Passes every element of {@code dataSet} through unchanged.
     *
     * <p>An empty derivative of the same data set ({@link #empty(DataSet)}) is attached as a
     * broadcast set named {@code "empty"}.
     *
     * @param dataSet the upstream data set
     * @param <T>     element type
     * @return a data set with the same elements and type, named {@code "barrier"}
     */
    public static <T> DataSet<T> barrier(DataSet<T> dataSet) {
        return dataSet
            .flatMap(new FlatMapFunction<T, T>() {
                private static final long serialVersionUID = -314924938979563398L;

                @Override
                public void flatMap(T value, Collector<T> out) {
                    out.collect(value);
                }
            })
            .withBroadcastSet(empty(dataSet), "empty")
            .name("barrier")
            .returns(dataSet.getType());
    }

    /**
     * Creates an empty row data set whose schema is the concatenation of two schemas.
     *
     * @param executionEnvironment the environment that owns the data set
     * @param tableSchema          leading columns
     * @param tableSchema2         trailing columns
     * @return an empty data set typed with the columns of both schemas, in order
     */
    public static DataSet<Row> createEmptyDataSet(ExecutionEnvironment executionEnvironment,
                                                  TableSchema tableSchema, TableSchema tableSchema2) {
        String[] fieldNames = (String[]) ArrayUtils.addAll(
            tableSchema.getFieldNames(), tableSchema2.getFieldNames());
        TypeInformation<?>[] fieldTypes = (TypeInformation<?>[]) ArrayUtils.addAll(
            tableSchema.getFieldTypes(), tableSchema2.getFieldTypes());
        return createEmptyDataSet(executionEnvironment, new TableSchema(fieldNames, fieldTypes));
    }

    /**
     * Creates an empty row data set typed with the given schema.
     *
     * <p>A single dummy element is generated and then dropped by a flat map, so that the result
     * carries an explicit {@link RowTypeInfo}.
     *
     * @param executionEnvironment the environment that owns the data set
     * @param tableSchema          schema of the (empty) rows
     * @return an empty data set of rows
     */
    public static DataSet<Row> createEmptyDataSet(ExecutionEnvironment executionEnvironment,
                                                  TableSchema tableSchema) {
        return executionEnvironment
            .fromElements(0)
            .flatMap(new FlatMapFunction<Integer, Row>() {
                private static final long serialVersionUID = 7566732134539040198L;

                @Override
                public void flatMap(Integer ignored, Collector<Row> out) {
                    // Intentionally emits nothing.
                }
            })
            .returns(new RowTypeInfo(tableSchema.getFieldTypes(), tableSchema.getFieldNames()));
    }

    /**
     * Drops the last field of every row.
     *
     * @param dataSet rows with arity of at least 1
     * @return rows containing all fields except the last one
     */
    public static DataSet<Row> removeLastColumn(DataSet<Row> dataSet) {
        return dataSet.map(new MapFunction<Row, Row>() {
            private static final long serialVersionUID = -6052009263843274262L;

            @Override
            public Row map(Row row) {
                int keptArity = row.getArity() - 1;
                Row trimmed = new Row(keptArity);
                for (int field = 0; field < keptArity; field++) {
                    trimmed.setField(field, row.getField(field));
                }
                return trimmed;
            }
        });
    }

    /**
     * Groups the rows of each partition into lists of at most {@code size} rows.
     *
     * <p>The last list of a partition may be shorter. Partitions without rows emit nothing. A
     * {@code size} of zero or less yields single-row lists.
     *
     * <p>NOTE(review): the input rows are buffered without copying. This assumes Flink object
     * reuse is disabled (the default); with object reuse enabled, buffered rows may be
     * overwritten by the iterator.
     *
     * @param dataSet the rows to batch
     * @param size    maximum number of rows per emitted list
     * @return per-partition batches of rows
     */
    public static DataSet<List<Row>> stack(DataSet<Row> dataSet, final int size) {
        return dataSet.mapPartition(new RichMapPartitionFunction<Row, List<Row>>() {
            private static final long serialVersionUID = 4066908302859627800L;

            @Override
            public void mapPartition(Iterable<Row> values, Collector<List<Row>> out) {
                List<Row> batch = new ArrayList<>();
                for (Row row : values) {
                    batch.add(row);
                    if (batch.size() >= size) {
                        out.collect(batch);
                        // Start a fresh list instead of clearing the emitted one, so an emitted
                        // batch is never mutated after being handed to the collector.
                        batch = new ArrayList<>();
                    }
                }
                if (!batch.isEmpty()) {
                    out.collect(batch);
                }
            }
        });
    }

    /**
     * Attaches a {@link DummyOutputFormat} sink to {@code dataSet}.
     *
     * @param dataSet the data set to terminate
     * @param <T>     element type
     */
    public static <T> void linkDummySink(DataSet<T> dataSet) {
        dataSet.output(new DummyOutputFormat());
    }

    /**
     * Reads a partitioned directory layout, without restricting the partition column names.
     *
     * @see #readFromPartitionBatch(Params, Long, SourceCollectorCreator, String[])
     */
    public static Tuple2<DataSet<Row>, TableSchema> readFromPartitionBatch(
        Params params, Long mlEnvironmentId, SourceCollectorCreator sourceCollectorCreator)
        throws IOException {
        return readFromPartitionBatch(params, mlEnvironmentId, sourceCollectorCreator, null);
    }

    /**
     * Reads the partitions selected by {@code params} and turns them into a row data set.
     *
     * <p>{@code AkUtils2.selectPartitionBatchOp} yields one row of partition values per
     * selected partition. For each such row, the directory
     * {@code <filePath>/<col1>=<v1>/<col2>=<v2>/...} is handed to
     * {@code sourceCollectorCreator}, which emits the data rows. Partition rows are rebalanced
     * first to spread the work across parallel instances.
     *
     * @param params                 must contain {@code HasFilePath.FILE_PATH}; the
     *                               {@code HasPartitions.PARTITIONS} value is forwarded as the
     *                               partition selector
     * @param mlEnvironmentId        forwarded to {@code AkUtils2.selectPartitionBatchOp}
     * @param sourceCollectorCreator reads one partition directory and supplies the output schema
     * @param partitionCols          forwarded to {@code AkUtils2.selectPartitionBatchOp}; may be
     *                               {@code null}
     * @return the data rows together with {@code sourceCollectorCreator.schema()}
     * @throws IOException if selecting partitions fails
     */
    public static Tuple2<DataSet<Row>, TableSchema> readFromPartitionBatch(
        Params params, Long mlEnvironmentId, final SourceCollectorCreator sourceCollectorCreator,
        String[] partitionCols) throws IOException {

        final FilePath rootPath = FilePath.deserialize((String) params.get(HasFilePath.FILE_PATH));
        BatchOperator<?> selectedPartitions = AkUtils2.selectPartitionBatchOp(
            mlEnvironmentId, rootPath, (String) params.get(HasPartitions.PARTITIONS), partitionCols);
        final String[] colNames = selectedPartitions.getColNames();

        DataSet<Row> data = selectedPartitions
            .getDataSet()
            .rebalance()
            .flatMap(new FlatMapFunction<Row, Row>() {
                @Override
                public void flatMap(Row partitionValues, Collector<Row> out) throws Exception {
                    Path partitionDir = rootPath.getPath();
                    for (int i = 0; i < partitionValues.getArity(); i++) {
                        partitionDir = new Path(partitionDir,
                            String.format("%s=%s", colNames[i], partitionValues.getField(i)));
                    }
                    sourceCollectorCreator.collect(
                        new FilePath(partitionDir, rootPath.getFileSystem()), out);
                }
            });

        return Tuple2.of(data, sourceCollectorCreator.schema());
    }

    /**
     * Writes the rows of {@code batchOperator} into one directory per distinct combination of
     * partition-column values.
     *
     * <p>Rows are grouped by the partition columns. Each group is written, without its partition
     * columns, to {@code <root>/<v1>/<v2>/.../0.inprogress} via {@code sinkCollectorCreator}.
     * The file is then renamed to {@code 0}. The resulting data set is terminated with a dummy
     * sink.
     *
     * <p>NOTE(review): directory names here are the bare partition values, while
     * {@link #readFromPartitionBatch} builds {@code col=value} names. Confirm which layout
     * {@code AkUtils2} expects before relying on round-tripping. A {@code null} partition value
     * fails with a NullPointerException.
     *
     * @param batchOperator        the rows to write
     * @param sinkCollectorCreator creates the collector that writes one partition file
     * @param params               must contain {@code HasFilePath.FILE_PATH} and
     *                             {@code HasPartitionColsDefaultAsNull.PARTITION_COLS}
     */
    public static void partitionAndWriteFile(BatchOperator<?> batchOperator,
                                             final SinkCollectorCreator sinkCollectorCreator,
                                             Params params) {
        TableSchema schema = batchOperator.getSchema();
        String[] partitionCols = (String[]) params.get(HasPartitionColsDefaultAsNull.PARTITION_COLS);
        final int[] partitionColIndices =
            TableUtil.findColIndicesWithAssertAndHint(schema, partitionCols);
        final int[] reservedColIndices = TableUtil.findColIndices(
            schema.getFieldNames(),
            org.apache.commons.lang3.ArrayUtils.removeElements(schema.getFieldNames(), partitionCols));
        final FilePath rootPath = FilePath.deserialize((String) params.get(HasFilePath.FILE_PATH));

        linkDummySink(batchOperator
            .getDataSet()
            .groupBy(partitionCols)
            .reduceGroup(new GroupReduceFunction<Row, byte[]>() {
                @Override
                public void reduce(Iterable<Row> rows, Collector<byte[]> out) throws IOException {
                    Path root = rootPath.getPath();
                    BaseFileSystem<?> fileSystem = rootPath.getFileSystem();

                    Collector<Row> fileCollector = null;
                    Path partitionDir = null;
                    try {
                        for (Row row : rows) {
                            if (fileCollector == null) {
                                // Open the output lazily on the first row of the group.
                                // Only the path component of the root URI is kept; the file
                                // system comes from rootPath separately.
                                partitionDir = new Path(root.getPath());
                                for (int colIndex : partitionColIndices) {
                                    partitionDir = new Path(partitionDir, row.getField(colIndex).toString());
                                }
                                fileSystem.mkdirs(partitionDir);
                                fileCollector = sinkCollectorCreator.createCollector(
                                    new FilePath(new Path(partitionDir, "0.inprogress"), fileSystem));
                            }
                            fileCollector.collect(Row.project(row, reservedColIndices));
                        }
                    } finally {
                        // Always release the sink, including when writing a row fails.
                        if (fileCollector != null) {
                            fileCollector.close();
                        }
                    }

                    // Only reached on success: publish the finished file under its final name.
                    if (fileCollector != null) {
                        fileSystem.rename(
                            new Path(partitionDir, "0.inprogress"), new Path(partitionDir, "0"));
                    }
                }
            }));
    }
}
