package com.alibaba.alink.common.dl;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.io.plugin.ResourcePluginFactory;
import com.alibaba.flink.ml.cluster.ExecutionMode;
import com.alibaba.flink.ml.cluster.MLConfig;
import java.util.Iterator;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/common/dl/DLClusterMapPartitionFunc.class */
/**
 * Map-partition function that is driven by the iteration superstep number.
 *
 * <p>In superstep 1 it runs the {@link IpPortFlatMapFunction}; in superstep 2 it feeds every input
 * row to the {@link DLFlatMapFunction}. Supersteps greater than 2 are rejected in
 * {@link #open(Configuration)}. Subtasks whose index is not below {@code numWorkers + numPSs} are
 * "dummy" tasks: they open nothing and only consume their input.
 */
public class DLClusterMapPartitionFunc extends RichMapPartitionFunction<Row, Row> implements ResultTypeQueryable<Row> {
    private final DLFlatMapFunction dlFlatMapFunction;
    private final IpPortFlatMapFunction ipPortFunction = new IpPortFlatMapFunction();
    private final int numWorkers;
    private final int numPSs;
    private final int numOutputFields;
    /** Superstep number captured in {@link #open(Configuration)}; stays 0 for dummy tasks. */
    private transient int stepNo;

    /**
     * Collector adapter used in superstep 1: every incoming row is re-emitted as a row of a fixed
     * arity whose last field carries field 0 of the incoming row. All other fields are left unset.
     */
    private static class IpPortCollector implements Collector<Row> {
        private final Collector<Row> downstream;
        private final int outputArity;

        public IpPortCollector(Collector<Row> downstream, int outputArity) {
            this.downstream = downstream;
            this.outputArity = outputArity;
        }

        public void collect(Row row) {
            Row widened = new Row(outputArity);
            Object firstField = row.getField(0);
            widened.setField(outputArity - 1, firstField);
            downstream.collect(widened);
        }

        /** Intentionally a no-op: the wrapped collector is not closed here. */
        public void close() {
        }
    }

    /**
     * Collector adapter used in superstep 2: every incoming row is copied into a row that is one
     * field longer; the extra trailing field is left unset.
     */
    private static class TFCollector implements Collector<Row> {
        private final Collector<Row> downstream;

        public TFCollector(Collector<Row> downstream) {
            this.downstream = downstream;
        }

        public void collect(Row row) {
            final int sourceArity = row.getArity();
            Row padded = new Row(sourceArity + 1);
            int pos = 0;
            while (pos < sourceArity) {
                padded.setField(pos, row.getField(pos));
                pos++;
            }
            downstream.collect(padded);
        }

        /** Intentionally a no-op: the wrapped collector is not closed here. */
        public void close() {
        }
    }

    /**
     * Builds the function for training mode.
     *
     * @param mLConfig              configuration; its properties must contain integer strings under
     *                              {@code DLConstants.NUM_WORKERS} and {@code DLConstants.NUM_PSS}
     * @param tableSchema           first schema handed to the wrapped {@link DLFlatMapFunction}
     * @param tableSchema2          second schema handed to the wrapped {@link DLFlatMapFunction};
     *                              its field count is recorded as the number of output fields
     * @param resourcePluginFactory plugin factory handed to the wrapped {@link DLFlatMapFunction}
     * @throws NumberFormatException if either count property cannot be parsed as an integer
     */
    public DLClusterMapPartitionFunc(MLConfig mLConfig, TableSchema tableSchema, TableSchema tableSchema2, ResourcePluginFactory resourcePluginFactory) {
        this.dlFlatMapFunction = new DLFlatMapFunction(ExecutionMode.TRAIN, mLConfig, tableSchema, tableSchema2, resourcePluginFactory);

        String workerCount = (String) mLConfig.getProperties().get(DLConstants.NUM_WORKERS);
        this.numWorkers = Integer.parseInt(workerCount);

        String psCount = (String) mLConfig.getProperties().get(DLConstants.NUM_PSS);
        this.numPSs = Integer.parseInt(psCount);

        String[] outputFieldNames = tableSchema2.getFieldNames();
        this.numOutputFields = outputFieldNames.length;
    }

    /**
     * Tells whether this subtask lies outside the worker/PS index range.
     *
     * @return {@code true} when the subtask index is at least {@code numWorkers + numPSs}
     */
    private boolean isDummyTask() {
        int activeTasks = numWorkers + numPSs;
        int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
        return subtaskIndex >= activeTasks;
    }

    /** Consumes every element of {@code rows} and discards it. */
    private static void drain(Iterable<Row> rows) {
        for (Row ignored : rows) {
            // nothing to do: the input only has to be read to the end
        }
    }

    /**
     * Records the current superstep number and opens the delegate that belongs to it.
     * Dummy tasks return immediately without touching the iteration context.
     *
     * @param configuration not used by this method
     * @throws Exception if the selected delegate fails to open, or if the superstep exceeds 2
     */
    public void open(Configuration configuration) throws Exception {
        if (isDummyTask()) {
            return;
        }
        stepNo = getIterationRuntimeContext().getSuperstepNumber();
        AkPreconditions.checkState(stepNo <= 2);
        switch (stepNo) {
            case 1:
                ipPortFunction.open(getRuntimeContext());
                break;
            case 2:
                dlFlatMapFunction.open(getRuntimeContext());
                break;
            default:
                break;
        }
    }

    /** Closes the delegate that matches the recorded superstep; dummy tasks close nothing. */
    public void close() {
        if (isDummyTask()) {
            return;
        }
        switch (stepNo) {
            case 1:
                ipPortFunction.close();
                break;
            case 2:
                dlFlatMapFunction.close();
                break;
            default:
                break;
        }
    }

    /**
     * Processes one partition according to the recorded superstep.
     *
     * <ul>
     *   <li>Dummy task: the input is read to the end and nothing is emitted by this method.</li>
     *   <li>Superstep 1: the IP/port function is invoked once with a {@code null} record, its output
     *       is routed through an {@link IpPortCollector} of arity {@code numOutputFields + 1}, and
     *       the input is then read to the end.</li>
     *   <li>Superstep 2: each input row is passed to the DL function, whose output is routed through
     *       a {@link TFCollector}.</li>
     * </ul>
     *
     * @param iterable  rows of this partition
     * @param collector sink for the produced rows
     * @throws Exception if a delegate fails
     */
    public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
        if (isDummyTask()) {
            drain(iterable);
            return;
        }
        switch (stepNo) {
            case 1: {
                IpPortCollector ipPortSink = new IpPortCollector(collector, numOutputFields + 1);
                ipPortFunction.flatMap(null, ipPortSink);
                drain(iterable);
                break;
            }
            case 2: {
                TFCollector paddedSink = new TFCollector(collector);
                for (Row record : iterable) {
                    dlFlatMapFunction.flatMap(record, paddedSink);
                }
                break;
            }
            default:
                break;
        }
    }

    /**
     * Reports the produced row type.
     *
     * @return the type information supplied by the wrapped {@link DLFlatMapFunction}
     */
    public TypeInformation<Row> getProducedType() {
        return dlFlatMapFunction.getProducedType();
    }
}
