package com.alibaba.alink.common.dl;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.dl.DLEnvConfig;
import com.alibaba.alink.common.dl.utils.DLUtils;
import com.alibaba.alink.common.dl.utils.DataSetDiskDownloader;
import com.alibaba.alink.common.dl.utils.PythonFileUtils;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.io.plugin.ResourcePluginFactory;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.flink.ml.cluster.ExecutionMode;
import com.alibaba.flink.ml.cluster.MLConfig;
import com.alibaba.flink.ml.cluster.node.MLContext;
import com.alibaba.flink.ml.cluster.role.BaseRole;
import com.alibaba.flink.ml.cluster.rpc.NodeServer;
import com.alibaba.flink.ml.data.DataExchange;
import com.alibaba.flink.ml.util.IpHostUtil;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import org.apache.commons.io.FileUtils;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/common/dl/DLFlatMapFunction.class */
public class DLFlatMapFunction implements Closeable, Serializable {
    private final MLConfig config;
    /** Type of the rows produced by this function, built from the output schema. */
    private final TypeInformation<Row> outTI;
    /** Number of output fields; the IP/port broadcast rows carry their "index-ip-port" string at this position. */
    private final int numOutputFields;
    private MLContext mlContext;
    private FutureTask<Void> serverFuture;
    private final ResourcePluginFactory factory;
    private final ExecutionMode mode;
    /** Last collector handed to {@link #flatMap}; reused by {@link #close()} to emit the remaining output. */
    private volatile Collector<Row> collector = null;
    private transient DataExchange<Row, Row> dataExchange;
    private static final Logger LOG = LoggerFactory.getLogger(DLFlatMapFunction.class);

    /**
     * Creates the function.
     *
     * @param executionMode         execution mode handed to the {@link MLContext} created in {@link #open}.
     * @param mLConfig              configuration whose properties are read and updated in {@link #open}.
     * @param inputSchema           schema of the input rows; it does not influence the produced type.
     * @param outputSchema          schema of the rows produced by this function.
     * @param resourcePluginFactory factory used to locate the default python env plugin.
     */
    public DLFlatMapFunction(ExecutionMode executionMode, MLConfig mLConfig, TableSchema inputSchema,
                             TableSchema outputSchema, ResourcePluginFactory resourcePluginFactory) {
        this.factory = resourcePluginFactory;
        this.mode = executionMode;
        this.config = mLConfig;
        this.outTI = new RowTypeInfo(outputSchema.getFieldTypes(), outputSchema.getFieldNames());
        this.numOutputFields = outputSchema.getFieldNames().length;
    }

    /**
     * Writes every broadcast variable named {@code DLConstants.BC_NAME_PREFIX + i} (i = 1, 2, ... while present)
     * to {@code <dir>/bc_data_<i>} and records each file path in the context properties under the broadcast name.
     *
     * <p>Each row becomes one line: its fields joined by a single space (null fields print as "null"), terminated
     * by {@link CsvInputFormat#DEFAULT_LINE_DELIMITER}. Files are written with the platform default charset.
     *
     * @param dir            directory the files are written into.
     * @param runtimeContext context providing the broadcast variables.
     * @param context        context whose properties receive the written file paths.
     * @throws AkUnclassifiedErrorException if a file cannot be written (the IOException is kept as the cause).
     */
    public static void prepareBroadcastData(String dir, RuntimeContext runtimeContext, MLContext context) {
        for (int i = 1; i < Integer.MAX_VALUE && runtimeContext.hasBroadcastVariable(DLConstants.BC_NAME_PREFIX + i);
             i++) {
            String bcName = DLConstants.BC_NAME_PREFIX + i;
            List<Row> rows = runtimeContext.getBroadcastVariable(bcName);
            String filePath = dir + File.separator + "bc_data_" + i;
            try {
                writeRows(rows, filePath);
                LOG.info("Succ in writing bc data to {}", filePath);
                DLUtils.safePutProperties(context, bcName, filePath);
            } catch (IOException e) {
                throw new AkUnclassifiedErrorException("Fail to write broadcast data to local disk.", e);
            }
        }
    }

    /** Writes {@code rows} to {@code filePath}, one formatted line per row, closing the file on every path. */
    private static void writeRows(List<Row> rows, String filePath) throws IOException {
        try (FileWriter fileWriter = new FileWriter(filePath);
             BufferedWriter writer = new BufferedWriter(fileWriter)) {
            for (Row row : rows) {
                writer.write(formatRow(row));
            }
        }
    }

    /** Joins the fields of {@code row} with single spaces and appends the CSV line delimiter. */
    private static String formatRow(Row row) {
        StringBuilder line = new StringBuilder();
        for (int pos = 0; pos < row.getArity(); pos++) {
            if (pos > 0) {
                line.append(' ');
            }
            line.append(row.getField(pos));
        }
        return line.append(CsvInputFormat.DEFAULT_LINE_DELIMITER).toString();
    }

    /**
     * Prepares this subtask and starts its node server.
     *
     * <p>Steps, in order: validates the IP/port broadcast against NUM_WORKERS + NUM_PSS, resolves the download path,
     * stores it in the config properties, creates the {@link MLContext}, configures the python env and entry script,
     * writes tensor shape/type maps and the numbered broadcast data, publishes all IPs/ports into the context,
     * creates the {@link DataExchange} and finally starts a {@link NodeServer} on a daemon thread.
     *
     * @param runtimeContext Flink runtime context of this subtask.
     * @throws Exception if any preparation step fails; node server start-up failures surface as IOException.
     */
    public void open(RuntimeContext runtimeContext) throws Exception {
        Map<String, String> properties = this.config.getProperties();
        int numWorkers = Integer.parseInt(properties.get(DLConstants.NUM_WORKERS));
        int numPss = Integer.parseInt(properties.get(DLConstants.NUM_PSS));

        List<Row> ipPortRows = runtimeContext.getBroadcastVariable(DLConstants.IP_PORT_BC_NAME);
        AkPreconditions.checkState(ipPortRows.size() == numWorkers + numPss, "Some IPs and ports are missing.");
        List<Tuple3<Integer, String, Integer>> ipPorts = parseIpPorts(ipPortRows);

        int indexOfThisSubtask = runtimeContext.getIndexOfThisSubtask();
        Tuple2<BaseRole, Integer> roleAndIndex = DLRunner.getRoleAndIndex(indexOfThisSubtask, numWorkers);
        BaseRole role = roleAndIndex.f0;

        List<Row> downloadPathRows = runtimeContext.getBroadcastVariable(DLConstants.BC_NAME_DOWNLOAD_PATHS);
        String downloadPath = DataSetDiskDownloader.getDownloadPath(
            downloadPathRows.stream().map(row -> (String) row.getField(0)).toArray(String[]::new));
        LOG.info("Worker {} uses download path: {}", indexOfThisSubtask, downloadPath);
        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.printf("worker %d use download path: %s%n", indexOfThisSubtask, downloadPath);
        }

        // Set before the MLContext is constructed from this.config.
        properties.put("current_work_dir", downloadPath);
        properties.put(DLConstants.WORK_DIR, downloadPath);
        // Raw-typed null kept as in the original call to select the constructor parameter type.
        this.mlContext = new MLContext(this.mode, this.config, role.name(), roleAndIndex.f1,
            this.config.getEnvPath(), (Map) null);

        configurePythonEnv(properties, downloadPath);

        String entryFileName = PythonFileUtils.getFileName(properties.get(DLConstants.ENTRY_SCRIPT));
        this.mlContext.setPythonDir(new File(downloadPath).toPath());
        this.mlContext.setPythonFiles(new String[] {new File(downloadPath, entryFileName).getAbsolutePath()});

        writeBroadcastJson(runtimeContext, DLConstants.BC_NAME_TENSOR_SHAPES, downloadPath, "tensor_shapes.txt",
            "tensor shape");
        writeBroadcastJson(runtimeContext, DLConstants.BC_NAME_TENSOR_TYPES, downloadPath, "tensor_types.txt",
            "tensor type");
        prepareBroadcastData(downloadPath, runtimeContext, this.mlContext);

        // Arrays are indexed by the task index carried in each broadcast entry.
        String[] ips = new String[ipPorts.size()];
        int[] ports = new int[ipPorts.size()];
        for (Tuple3<Integer, String, Integer> ipPort : ipPorts) {
            int taskIndex = ipPort.f0;
            ips[taskIndex] = ipPort.f1;
            if (taskIndex == indexOfThisSubtask) {
                // The IP recorded for this subtask must match the host it is running on now.
                AkPreconditions.checkState(ips[taskIndex].equals(IpHostUtil.getIpAddress()),
                    "task allocation changed");
            }
            ports[taskIndex] = ipPort.f2;
        }
        DLUtils.safePutProperties(this.mlContext, DLRunner.IPS, JsonConverter.toJson(ips));
        DLUtils.safePutProperties(this.mlContext, DLRunner.PORTS, JsonConverter.toJson(ports));

        this.dataExchange = new DataExchange<>(this.mlContext);
        startNodeServer(role);
    }

    /** Parses each broadcast row's "index-ip-port" string (stored at {@link #numOutputFields}) into a tuple. */
    private List<Tuple3<Integer, String, Integer>> parseIpPorts(List<Row> ipPortRows) {
        List<Tuple3<Integer, String, Integer>> ipPorts = new ArrayList<>(ipPortRows.size());
        for (Row row : ipPortRows) {
            String[] parts = ((String) row.getField(this.numOutputFields)).split("-");
            ipPorts.add(Tuple3.of(Integer.parseInt(parts[0]), parts[1], Integer.parseInt(parts[2])));
        }
        return ipPorts;
    }

    /**
     * Resolves the python env directory and stores it under "virtual_env_dir".
     *
     * <p>No PYTHON_ENV: the default env of plugin version ENV_VERSION is used. A local-file PYTHON_ENV: its path after
     * the "file://" prefix is used. Otherwise PYTHON_ENV is resolved relative to {@code downloadPath}.
     */
    private void configurePythonEnv(Map<String, String> properties, String downloadPath) {
        String pythonEnv = properties.get(DLConstants.PYTHON_ENV);
        String envDir;
        if (StringUtils.isNullOrWhitespaceOnly(pythonEnv)) {
            DLEnvConfig.Version version = DLEnvConfig.Version.valueOf(properties.get(DLConstants.ENV_VERSION));
            LOG.info("Use pythonEnv from plugin: {}", version);
            envDir = DLEnvConfig.getDefaultPythonEnv(this.factory, version).substring("file://".length());
        } else if (PythonFileUtils.isLocalFile(pythonEnv)) {
            LOG.info("Use pythonEnv from local file: {}", pythonEnv);
            envDir = pythonEnv.substring("file://".length());
        } else {
            LOG.info("Use pythonEnv from download path: {}", pythonEnv);
            envDir = new File(downloadPath, pythonEnv).getAbsolutePath();
        }
        properties.put("virtual_env_dir", envDir);
    }

    /**
     * If broadcast variable {@code bcName} exists, serializes its first element (expected to be a Map) as JSON into
     * {@code <dir>/<fileName>}; {@code description} is only used in the log message.
     */
    private static void writeBroadcastJson(RuntimeContext runtimeContext, String bcName, String dir, String fileName,
                                           String description) throws IOException {
        if (!runtimeContext.hasBroadcastVariable(bcName)) {
            return;
        }
        Map<?, ?> content = (Map<?, ?>) runtimeContext.getBroadcastVariable(bcName).get(0);
        File file = new File(dir, fileName);
        FileUtils.write(file, JsonConverter.toJson(content));
        LOG.info("Succ in writing {} map to {}", description, file.getAbsolutePath());
    }

    /** Starts the {@link NodeServer} for {@code role} on a daemon thread named after the context identity. */
    private void startNodeServer(BaseRole role) throws IOException {
        try {
            this.serverFuture = new FutureTask<>(new NodeServer(this.mlContext, role.name()), null);
            Thread serverThread = new Thread(this.serverFuture);
            serverThread.setDaemon(true);
            serverThread.setName("NodeServer_" + this.mlContext.getIdentity());
            serverThread.start();
            LOG.info("start: {}, index: {}", this.mlContext.getRoleName(), this.mlContext.getIndex());
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("start:" + this.mlContext.getRoleName() + " index:" + this.mlContext.getIndex());
            }
        } catch (Exception e) {
            LOG.error("Fail to start node service.", e);
            // Keep the original exception as the cause instead of only its message.
            throw new IOException(e.getMessage(), e);
        }
    }

    /**
     * Finishes this subtask.
     *
     * <p>Does nothing if {@link #open} never created a context. Otherwise marks the output queue finished; for the
     * "ps" role it returns right away. For other roles it drains the remaining output into the last collector, waits
     * for the node server, then always releases the context (see {@link #releaseResources()}).
     *
     * @throws AkUnclassifiedErrorException if the node server failed, draining failed, or the context reports a
     *                                      non-zero fail count (that last check runs during cleanup and therefore
     *                                      takes precedence over an earlier failure).
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        if (null == this.mlContext) {
            return;
        }
        if (this.mlContext.getOutputQueue() != null) {
            this.mlContext.getOutputQueue().markFinished();
        }
        if ("ps".equals(this.mlContext.getRoleName())) {
            LOG.info("PS job return");
            return;
        }
        boolean interrupted = false;
        try {
            // dataExchange is null when open() failed before creating it; nothing to drain then.
            if (this.dataExchange != null) {
                drainRead(this.collector, true);
            }
            if (this.serverFuture != null && !this.serverFuture.isCancelled()) {
                this.serverFuture.get();
            }
        } catch (InterruptedException e) {
            interrupted = true;
            LOG.error("Interrupted waiting for server join {}.", e.getMessage());
            this.serverFuture.cancel(true);
        } catch (ExecutionException e) {
            String identity = this.mlContext.getIdentity();
            LOG.error(identity + " node server failed");
            throw new AkUnclassifiedErrorException(identity + " node server failed", e);
        } catch (Throwable th) {
            throw new AkUnclassifiedErrorException("Exception thrown.", th);
        } finally {
            try {
                releaseResources();
            } finally {
                // Restore the interrupt status only after cleanup so it cannot disturb mlContext.close().
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /**
     * Drops the server future, closes the context (logging close failures), and logs the number of records read.
     *
     * @throws AkUnclassifiedErrorException if the context reported a non-zero fail count.
     */
    private void releaseResources() {
        this.serverFuture = null;
        long readRecords = this.dataExchange == null ? 0L : this.dataExchange.getReadRecords();
        int failNum = 0;
        if (this.mlContext != null) {
            failNum = this.mlContext.getFailNum();
            try {
                this.mlContext.close();
            } catch (IOException e) {
                LOG.error("Fail to close mlContext.", e);
            }
            this.mlContext = null;
        }
        if (failNum > 0) {
            throw new AkUnclassifiedErrorException("Python script run failed, please check TaskManager logs.");
        }
        LOG.info("Records output: {}", readRecords);
    }

    /**
     * Emits the already-available output, then writes {@code row} (encoded with
     * {@link DLUtils#encodeStringValue}) into the data exchange, yielding and retrying while the write is rejected.
     */
    public void flatMap(Row row, Collector<Row> collector) throws Exception {
        boolean write;
        this.collector = collector;
        do {
            drainRead(this.collector, false);
            write = this.dataExchange.write(DLUtils.encodeStringValue(row));
            if (!write) {
                Thread.yield();
            }
        } while (!write);
    }

    /** Returns the type of the rows produced by this function (derived from the output schema). */
    public TypeInformation<Row> getProducedType() {
        return this.outTI;
    }

    /**
     * Reads rows from the data exchange and emits them until a read returns null.
     *
     * <p>{@code readUntilEnd} is passed straight to {@code DataExchange.read}; presumably it makes the read wait for
     * the end of the output instead of returning what is available — TODO confirm against DataExchange.
     * If a read is interrupted, the node server is cancelled and draining stops, because no further data can arrive
     * from a cancelled server.
     *
     * @throws AkUnclassifiedErrorException if reading fails with any other IOException.
     */
    private void drainRead(Collector<Row> collector, boolean readUntilEnd) {
        while (true) {
            Row row;
            try {
                row = (Row) this.dataExchange.read(readUntilEnd);
            } catch (InterruptedIOException e) {
                LOG.info("{} Reading from is interrupted, canceling the server", this.mlContext.getIdentity());
                if (this.serverFuture != null) {
                    this.serverFuture.cancel(true);
                }
                return;
            } catch (IOException e) {
                LOG.error("Fail to read data from python.", e);
                throw new AkUnclassifiedErrorException("Fail to read data from python.", e);
            }
            if (row == null) {
                return;
            }
            collector.collect(row);
        }
    }
}
