package com.alibaba.alink.common.dl;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.dl.coding.ExampleCodingV2;
import com.alibaba.alink.common.dl.data.TFRecordReaderImpl;
import com.alibaba.alink.common.dl.data.TFRecordWriterImpl;
import com.alibaba.alink.common.dl.utils.DLLauncherUtils;
import com.alibaba.alink.common.dl.utils.DLTypeUtils;
import com.alibaba.alink.common.dl.utils.DLUtils;
import com.alibaba.alink.common.dl.utils.DataSetDiskDownloader;
import com.alibaba.alink.common.dl.utils.PythonFileUtils;
import com.alibaba.alink.common.io.plugin.OsType;
import com.alibaba.alink.common.io.plugin.OsUtils;
import com.alibaba.alink.common.io.plugin.ResourcePluginFactory;
import com.alibaba.alink.common.linalg.tensor.Tensor;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetUtil;
import com.alibaba.alink.operator.common.dataproc.FirstReducer;
import com.alibaba.alink.params.dl.DLLauncherParams;
import com.alibaba.flink.ml.tensorflow2.client.DLConfig;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(value = PortType.DATA, desc = PortDesc.DL_BC_DATA, isRepeated = true)})
@OutputPorts(values = {@PortSpec(PortType.DATA)})
@Internal
/* loaded from: input_file:com/alibaba/alink/common/dl/DLLauncherBatchOp.class */
public final class DLLauncherBatchOp extends BatchOperator<DLLauncherBatchOp> implements DLLauncherParams<DLLauncherBatchOp> {
    private static final Logger LOG = LoggerFactory.getLogger(DLLauncherBatchOp.class);
    private final ResourcePluginFactory factory;

    /**
     * Creates the operator with an empty parameter set by delegating to
     * {@link #DLLauncherBatchOp(Params)} with a fresh {@link Params}.
     */
    public DLLauncherBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator from the given parameters.
     *
     * <p>Besides forwarding {@code params} to the superclass, this allocates the
     * {@link ResourcePluginFactory} that {@code linkFrom} later hands to the
     * cluster map-partition function.
     *
     * @param params operator parameters passed through to the base operator
     */
    public DLLauncherBatchOp(Params params) {
        super(params);
        factory = new ResourcePluginFactory();
    }

    /**
     * Assembles the {@link DLConfig} used to start the DL cluster.
     *
     * <p>The configuration is created from the current worker/PS counts and entry
     * function, gets its example-coding type from the two schemas, and is then
     * filled with the Python environment (or, when no environment path is set,
     * the environment version), the entry script/function, user parameters, node
     * counts, and the runner/storage/reader/writer/codec class names.
     *
     * @param tableSchema  schema passed as the first schema to {@code DLUtils.setExampleCodingType}
     * @param tableSchema2 schema passed as the second schema to {@code DLUtils.setExampleCodingType}
     * @return the populated configuration
     */
    private DLConfig setupDLConfig(TableSchema tableSchema, TableSchema tableSchema2) {
        final int numWorkers = getNumWorkers().intValue();
        final int numPSs = getNumPSs().intValue();

        DLConfig config = new DLConfig(numWorkers, numPSs, new HashMap<>(), (String) null, getEntryFunc(), (String) null);
        DLUtils.setExampleCodingType(config, tableSchema, tableSchema2);

        // An explicit Python environment path takes precedence over the environment version.
        final String pythonEnv = getPythonEnv();
        if (StringUtils.isNullOrWhitespaceOnly(pythonEnv)) {
            if (getEnvVersion() != null) {
                DLUtils.safePutProperties(config, DLConstants.ENV_VERSION, getEnvVersion().name());
            }
        } else {
            DLUtils.safePutProperties(config, DLConstants.PYTHON_ENV, pythonEnv);
        }

        // Entry point and sizing of the cluster.
        DLUtils.safePutProperties(config, DLConstants.ENTRY_SCRIPT, getMainScriptFile());
        DLUtils.safePutProperties(config, DLConstants.ENTRY_FUNC, getEntryFunc());
        DLUtils.safePutProperties(config, DLConstants.USER_DEFINED_PARAMS, getUserParams());
        DLUtils.safePutProperties(config, DLConstants.NUM_WORKERS, String.valueOf(numWorkers));
        DLUtils.safePutProperties(config, DLConstants.NUM_PSS, String.valueOf(numPSs));
        DLUtils.safePutProperties(config, "node.idle.timeout", String.valueOf(5000));

        // Runtime plumbing: runner, storage backend, record reader/writer and codec.
        final String codecClassName = ExampleCodingV2.class.getCanonicalName();
        DLUtils.safePutProperties(config, "sys:ml_runner_class", DLRunner.class.getCanonicalName());
        DLUtils.safePutProperties(config, "storage_type", "storage_custom");
        DLUtils.safePutProperties(config, "storage_impl_class", MemoryStorageImplV2.class.getName());
        DLUtils.safePutProperties(config, "sys:record_reader_class", TFRecordReaderImpl.class.getCanonicalName());
        DLUtils.safePutProperties(config, "sys:record_writer_class", TFRecordWriterImpl.class.getCanonicalName());
        DLUtils.safePutProperties(config, "sys:encoding_class", codecClassName);
        DLUtils.safePutProperties(config, "sys:decoding_class", codecClassName);
        DLUtils.safePutProperties(config, DLConstants.INTRA_OP_PARALLELISM, String.valueOf(getIntraOpParallelism()));
        return config;
    }

    /**
     * Redistributes rows so that they land on the partitions selected by slot
     * numbers {@code 0 .. i-1}.
     *
     * <p>Each source partition tags its rows with a slot number that cycles through
     * {@code 0 .. i-1}. If a non-empty partition stops part-way through a cycle, its
     * last row is repeated for the remaining slots of that cycle. The tagged rows are
     * then custom-partitioned by {@code slot % numPartitions} and the tag is stripped.
     *
     * @param dataSet rows to redistribute
     * @param i       number of slots to cycle through
     * @return the redistributed rows (possibly including repeated padding rows)
     */
    private DataSet<Row> dataSetFirstNPartitionStrictRebalance(DataSet<Row> dataSet, final int i) {
        // Stage 1: tag every row with a cycling slot number and pad the final cycle.
        RichMapPartitionFunction<Row, Tuple2<Integer, Row>> slotTagger = new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() {
            @Override
            public void mapPartition(Iterable<Row> rows, Collector<Tuple2<Integer, Row>> out) throws Exception {
                int slot = 0;
                Row lastSeen = null;
                for (Row current : rows) {
                    out.collect(Tuple2.of(slot, current));
                    lastSeen = current;
                    slot = (slot + 1 == i) ? 0 : slot + 1;
                }
                // Pad only when the partition had data and ended in the middle of a cycle.
                if (lastSeen != null && slot != 0) {
                    for (int pad = slot; pad < i; pad++) {
                        out.collect(Tuple2.of(pad, lastSeen));
                    }
                }
            }
        };

        // Stage 2: route by slot number.
        Partitioner<Integer> bySlot = new Partitioner<Integer>() {
            private static final long serialVersionUID = -44838855219045312L;

            @Override
            public int partition(Integer key, int numPartitions) {
                return key % numPartitions;
            }
        };

        // Stage 3: drop the slot tag again.
        MapFunction<Tuple2<Integer, Row>, Row> untag = new MapFunction<Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = 5543012093523253627L;

            @Override
            public Row map(Tuple2<Integer, Row> tagged) throws Exception {
                return tagged.f1;
            }
        };

        return dataSet.mapPartition(slotTagger).partitionCustom(bySlot, 0).map(untag);
    }

    /**
     * Builds the Flink batch plan that launches the DL cluster and exposes its output.
     *
     * <p>The first input supplies the data rows (after double columns are converted to
     * float); every further input is attached to the cluster task as a broadcast set
     * named {@code DLConstants.BC_NAME_PREFIX + index}. The cluster task runs inside a
     * two-round iteration and additionally receives the tensor shapes, tensor types and
     * downloaded file paths as broadcast sets. After the iteration a follow-up task
     * issues an OS-specific command to kill leftover Python processes; the final output
     * waits on it through a broadcast set named {@code "barrier"}.
     *
     * @param batchOperatorArr input operators; at least one is required
     * @return this operator, with its output set
     * @throws IllegalArgumentException if no input operator is supplied
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public DLLauncherBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        if (batchOperatorArr == null || batchOperatorArr.length == 0) {
            throw new IllegalArgumentException("DLLauncherBatchOp requires at least one input operator.");
        }

        BatchOperator<?> input = DLTypeUtils.doubleColumnsToFloat(batchOperatorArr[0]);
        setMLEnvironmentId(input.getMLEnvironmentId());
        ExecutionEnvironment env = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();

        // Worker/PS counts are adjusted against the environment parallelism and written back.
        Tuple2<Integer, Integer> adjusted = DLLauncherUtils.adjustNumWorkersPSs(getNumWorkers(), getNumPSs(), env.getParallelism());
        setNumWorkers(adjusted.f0);
        setNumPSs(adjusted.f1);

        DataSet<Row> inputRows = input.getDataSet();
        int numWorkers = getNumWorkers().intValue();
        int numPSs = getNumPSs().intValue();
        TableSchema outputSchema = TableUtil.schemaStr2Schema(getOutputSchemaStr());
        DLConfig dlConfig = setupDLConfig(input.getSchema(), outputSchema);

        ExternalFilesConfig userFiles = getUserFiles();
        ArrayList<String> filesToDownload = new ArrayList<>(userFiles.getFilePaths());
        Map<String, String> fileRenameMap = userFiles.getFileRenameMap();

        // A non-local Python environment is queued for download and referenced by its
        // compressed file name; a local one is referenced by its path as given.
        String pythonEnv = getPythonEnv();
        if (!StringUtils.isNullOrWhitespaceOnly(pythonEnv)) {
            String resolvedPythonEnv = pythonEnv;
            if (!PythonFileUtils.isLocalFile(pythonEnv)) {
                filesToDownload.add(pythonEnv);
                resolvedPythonEnv = PythonFileUtils.getCompressedFileName(pythonEnv);
            }
            DLUtils.safePutProperties(dlConfig, DLConstants.PYTHON_ENV, resolvedPythonEnv);
        }

        DataSet<Row> rebalancedInput;
        if (numPSs > 0) {
            rebalancedInput = dataSetFirstNPartitionStrictRebalance(inputRows, numWorkers);
        } else {
            rebalancedInput = dataSetFirstNPartitionStrictRebalance(inputRows, numWorkers + numPSs);
        }
        rebalancedInput.output(new DiscardingOutputFormat<Row>());

        TableSchema ipPortSchema = new TableSchema(new String[]{"ip_port"}, new TypeInformation[]{Types.STRING});
        DLClusterMapPartitionFunc clusterFunc = new DLClusterMapPartitionFunc(dlConfig.getMlConfig(), input.getSchema(), outputSchema, this.factory);
        IterativeDataSet iterate = DataSetUtil.createEmptyDataSet(env, outputSchema, ipPortSchema).iterate(2);
        MapPartitionOperator cluster = rebalancedInput.mapPartition(clusterFunc).withBroadcastSet(iterate, DLConstants.IP_PORT_BC_NAME).name("DL_CLUSTER");

        // Inputs after the first are made available to the cluster task as broadcast sets.
        for (int idx = 1; idx < batchOperatorArr.length; idx++) {
            cluster = (MapPartitionOperator) cluster.withBroadcastSet(batchOperatorArr[idx].getDataSet(), DLConstants.BC_NAME_PREFIX + idx);
        }
        cluster.withBroadcastSet(extractTensorShapes(input.getDataSet(), input.getColNames()), DLConstants.BC_NAME_TENSOR_SHAPES).withBroadcastSet(extractTensorTypes(input.getDataSet(), input.getColNames()), DLConstants.BC_NAME_TENSOR_TYPES);
        DataSet closeWith = iterate.closeWith(cluster.withBroadcastSet(DataSetDiskDownloader.downloadFilesWithRename(getMLEnvironmentId(), filesToDownload, fileRenameMap).getDataSet(), DLConstants.BC_NAME_DOWNLOAD_PATHS));

        setOutput(DataSetUtil.removeLastColumn(closeWith.map(new MapFunction<Row, Row>() {
            @Override
            public Row map(Row row) throws Exception {
                // Identity map: exists only to carry the "barrier" broadcast dependency.
                return row;
            }
        }).withBroadcastSet(closeWith.mapPartition(new MapPartitionFunction<Row, Object>() {
            @Override
            public void mapPartition(Iterable<Row> iterable, Collector<Object> collector) throws Exception {
                DLLauncherBatchOp.LOG.info("killing DL tasks------");
                if (OsType.WINDOWS.equals(OsUtils.getSystemType())) {
                    // NOTE(review): cmd.exe is invoked without "/c" — confirm this command string is actually executed.
                    Runtime.getRuntime().exec(new String[]{"cmd.exe", "for /f \"skip=1 tokens=1,2 delims=, \" %a in ('tasklist /fi \" IMAGENAME eq python.exe\" /FO csv ')  do (  wmic process where processid=%b get commandline | findstr startup.py | taskkill /pid %b -f )"}, (String[]) null, (File) null);
                } else {
                    Runtime.getRuntime().exec(new String[]{"/bin/bash", "-c", "ps -ef | grep \"temp_[0-9*]_.*/startup.py\" | awk '{print $2}' | xargs kill -9"}, (String[]) null, (File) null);
                }
            }
        }), "barrier")), outputSchema);
        return this;
    }

    /**
     * Derives the tensor shapes of the input columns from a single sample row.
     *
     * <p>The data set is reduced with {@code FirstReducer(1)}; for each row that
     * remains, one map is emitted that associates the name of every column whose
     * value is a {@link Tensor} with that tensor's {@code shape()}. Columns holding
     * other values are omitted.
     *
     * @param dataSet rows to sample from
     * @param strArr  column names, indexed like the row fields
     * @return a data set of column-name to shape maps
     */
    private DataSet<Map<String, long[]>> extractTensorShapes(DataSet<Row> dataSet, final String[] strArr) {
        return dataSet.reduceGroup(new FirstReducer(1)).flatMap(new FlatMapFunction<Row, Map<String, long[]>>() {
            @Override
            public void flatMap(Row row, Collector<Map<String, long[]>> collector) throws Exception {
                Map<String, long[]> shapes = new HashMap<>();
                for (int i = 0; i < strArr.length; i++) {
                    Object field = row.getField(i);
                    if (field instanceof Tensor) {
                        shapes.put(strArr[i], ((Tensor) field).shape());
                    }
                }
                collector.collect(shapes);
            }
        });
    }

    /**
     * Derives the tensor element types of the input columns from a single sample row.
     *
     * <p>The data set is reduced with {@code FirstReducer(1)}; for each row that
     * remains, one map is emitted that associates the name of every column whose
     * value is a {@link Tensor} with the lower-cased name of that tensor's type.
     * Lower-casing uses {@link Locale#ROOT} so the result does not depend on the
     * JVM's default locale.
     *
     * @param dataSet rows to sample from
     * @param strArr  column names, indexed like the row fields
     * @return a data set of column-name to type-name maps
     */
    private DataSet<Map<String, String>> extractTensorTypes(DataSet<Row> dataSet, final String[] strArr) {
        return dataSet.reduceGroup(new FirstReducer(1)).flatMap(new FlatMapFunction<Row, Map<String, String>>() {
            @Override
            public void flatMap(Row row, Collector<Map<String, String>> collector) {
                Map<String, String> typeNames = new HashMap<>();
                for (int i = 0; i < strArr.length; i++) {
                    Object field = row.getField(i);
                    if (field instanceof Tensor) {
                        typeNames.put(strArr[i], ((Tensor) field).getType().name().toLowerCase(Locale.ROOT));
                    }
                }
                collector.collect(typeNames);
            }
        });
    }

    /**
     * Bridge overload (marked bridge/synthetic in the decompiled output) that casts
     * its array argument and forwards to {@code linkFrom(BatchOperator<?>...)}.
     *
     * <p>NOTE(review): this signature has the same erasure as the varargs
     * {@code linkFrom} above, so it looks like a decompiler artifact rather than
     * hand-written source — confirm whether it should be kept.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ DLLauncherBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
