package com.alibaba.alink.common.dl;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.flink.ml.cluster.master.meta.AMMetaImpl;
import com.alibaba.flink.ml.cluster.node.MLContext;
import com.alibaba.flink.ml.cluster.node.runner.ExecutionStatus;
import com.alibaba.flink.ml.cluster.node.runner.FlinkKillException;
import com.alibaba.flink.ml.cluster.node.runner.MLRunner;
import com.alibaba.flink.ml.cluster.node.runner.ScriptRunner;
import com.alibaba.flink.ml.cluster.node.runner.ScriptRunnerFactory;
import com.alibaba.flink.ml.cluster.role.BaseRole;
import com.alibaba.flink.ml.cluster.role.PsRole;
import com.alibaba.flink.ml.cluster.role.WorkerRole;
import com.alibaba.flink.ml.cluster.rpc.NodeServer;
import com.alibaba.flink.ml.proto.MLClusterDef;
import com.alibaba.flink.ml.proto.NodeSpec;
import com.alibaba.flink.ml.util.IpHostUtil;
import com.alibaba.flink.ml.util.MLException;
import com.alibaba.flink.ml.util.ProtoUtil;
import java.io.IOException;
import java.io.Serializable;
import java.net.ServerSocket;
import java.util.Map;
import org.apache.commons.io.IOUtils;
import org.apache.flink.api.java.tuple.Tuple2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/common/dl/DLRunner.class */
/**
 * {@link MLRunner} implementation that assembles the cluster definition, publishes it to the
 * {@link MLContext} and then executes the user script through a {@link ScriptRunner}.
 *
 * <p>The cluster definition is built in one of two ways, see {@link #getClusterInfo()}:
 * when both {@link #IPS} and {@link #PORTS} are present in the context properties, every node
 * is described from those lists; otherwise only the local node is registered.
 *
 * <p>Several {@link MLRunner} life-cycle hooks ({@link #registerNode()},
 * {@link #startHeartBeat()}, {@link #getCurrentJobVersion()}, {@link #initAMClient()},
 * {@link #waitClusterRunning()}) are intentionally empty in this runner and are not invoked by
 * {@link #run()}.
 */
public class DLRunner implements MLRunner, Serializable {
    /** Property key holding the JSON-encoded {@code String[]} of node IPs. */
    public static final String IPS = "Alink:dl_ips";
    /** Property key holding the JSON-encoded {@code int[]} of node ports. */
    public static final String PORTS = "Alink:dl_ports";
    private static final Logger LOG = LoggerFactory.getLogger(DLRunner.class);

    protected long version = 0;
    /** IP address of this host; assigned at the beginning of {@link #run()}. */
    protected String localIp;
    protected NodeServer server;
    protected volatile MLContext mlContext;
    /** Runner of the script currently executing; {@code null} when no script is active. */
    protected ScriptRunner scriptRunner;
    /** Status exposed through {@link #getResultStatus()}; also written by {@link #notifyStop()}. */
    protected ExecutionStatus resultStatus;
    /** Working status of {@link #run()}; copied into {@link #resultStatus} once cleanup is done. */
    protected ExecutionStatus currentResultStatus;
    protected MLClusterDef mlClusterDef;

    /**
     * Creates a runner bound to the given context and node server.
     *
     * @param mlContext context providing the properties, role name and index of this node
     * @param server    node server whose port is advertised and which receives the STOP command
     */
    public DLRunner(MLContext mlContext, NodeServer server) {
        this.mlContext = mlContext;
        this.server = server;
    }

    /** No-op in this runner. */
    public void registerNode() throws Exception {
    }

    /**
     * Maps a flat node position onto a role and an index inside that role.
     *
     * <p>Positions {@code [0, numWorkers)} are workers and keep their position as index; every
     * other position is a parameter server whose index is {@code index - numWorkers}.
     *
     * @param index      flat position of the node
     * @param numWorkers number of worker nodes
     * @return the role paired with the index of the node within that role
     */
    public static Tuple2<BaseRole, Integer> getRoleAndIndex(int index, int numWorkers) {
        final boolean isWorker = index < numWorkers;
        final BaseRole role = isWorker ? new WorkerRole() : new PsRole();
        final int roleIndex = isWorker ? index : index - numWorkers;
        return Tuple2.of(role, roleIndex);
    }

    /**
     * Saves node specs into an {@link AMMetaImpl} and restores {@link #mlClusterDef} from it.
     *
     * <p>Standalone mode (either {@link #IPS} or {@link #PORTS} is missing from the context
     * properties): a single spec for the local node is saved, using a currently free local port
     * as {@code sys:tf_port}.
     *
     * <p>Cluster mode: one spec per entry of the IP list is saved, with the role and index
     * derived by {@link #getRoleAndIndex(int, int)} and the matching entry of the port list as
     * {@code sys:tf_port}.
     *
     * @throws InterruptedException declared for interface compatibility
     * @throws IOException          if the probing socket cannot be obtained or closed
     */
    public void getClusterInfo() throws InterruptedException, IOException {
        final AMMetaImpl amMeta = new AMMetaImpl(this.mlContext);
        // Wildcard view: only String keys are looked up and the values read are cast to String.
        final Map<?, ?> properties = this.mlContext.getProperties();

        final boolean clusterMode = properties.containsKey(IPS) && properties.containsKey(PORTS);
        if (!clusterMode) {
            LOG.info("Running in standalone mode.");
            final NodeSpec localSpec;
            // try-with-resources guarantees the probing socket is released even when building
            // the spec fails; it is closed before the spec is saved so the port is free again.
            try (ServerSocket freeSocket = IpHostUtil.getFreeSocket()) {
                localSpec = NodeSpec.newBuilder()
                    .setIp(this.localIp)
                    .setIndex(this.mlContext.getIndex())
                    .setClientPort(this.server.getPort())
                    .setRoleName(this.mlContext.getRoleName())
                    .putProps("sys:tf_port", String.valueOf(freeSocket.getLocalPort()))
                    .build();
            }
            amMeta.saveNodeSpec(localSpec);
            this.mlClusterDef = amMeta.restoreClusterDef();
            return;
        }

        LOG.info("Running in cluster mode.");
        final int numWorkers = Integer.parseInt((String) properties.get(DLConstants.NUM_WORKERS));
        final String[] ips = JsonConverter.fromJson((String) properties.get(IPS), String[].class);
        final int[] ports = JsonConverter.fromJson((String) properties.get(PORTS), int[].class);
        for (int i = 0; i < ips.length; i++) {
            final Tuple2<BaseRole, Integer> roleAndIndex = getRoleAndIndex(i, numWorkers);
            final NodeSpec nodeSpec = NodeSpec.newBuilder()
                .setIp(ips[i])
                .setIndex(roleAndIndex.f1)
                .setClientPort(0)
                .setRoleName(roleAndIndex.f0.name())
                .putProps("sys:tf_port", String.valueOf(ports[i]))
                .build();
            amMeta.saveNodeSpec(nodeSpec);
        }
        this.mlClusterDef = amMeta.restoreClusterDef();
    }

    /**
     * Aborts the current run when a stop was requested through {@link #notifyStop()}.
     *
     * @throws MLException a {@link FlinkKillException} if the result status is
     *                     {@link ExecutionStatus#KILLED_BY_FLINK}
     */
    protected void checkEnd() throws MLException {
        if (this.resultStatus == ExecutionStatus.KILLED_BY_FLINK) {
            throw new FlinkKillException("Exit per request.");
        }
    }

    /**
     * Executes the full life cycle: resolve the local IP, build the cluster definition, publish
     * it to the context and run the script, checking for a stop request between the steps.
     *
     * <p>A {@link FlinkKillException} or {@link InterruptedException} ends the run as
     * {@link ExecutionStatus#KILLED_BY_FLINK}; any other throwable is logged, counted as a
     * failure on the context and ends the run as {@link ExecutionStatus#FAILED}. In every case
     * the script runner is released, the final status is published and the node server is told
     * to stop.
     */
    public void run() {
        this.resultStatus = ExecutionStatus.RUNNING;
        this.currentResultStatus = ExecutionStatus.RUNNING;
        try {
            this.localIp = IpHostUtil.getIpAddress();
            getClusterInfo();
            AkPreconditions.checkNotNull(this.mlClusterDef);
            checkEnd();
            resetMLContext();
            checkEnd();
            runScript();
            checkEnd();
            LOG.info("run script.");
            this.currentResultStatus = ExecutionStatus.SUCCEED;
        } catch (Throwable th) {
            if ((th instanceof FlinkKillException) || (th instanceof InterruptedException)) {
                LOG.info("{} killed by flink.", this.mlContext.getIdentity());
                this.currentResultStatus = ExecutionStatus.KILLED_BY_FLINK;
            } else {
                // Nobody asked this run to stop, so the throwable is a genuine failure.
                LOG.error("Got exception during python running", th);
                this.mlContext.addFailNum();
                this.currentResultStatus = ExecutionStatus.FAILED;
            }
        } finally {
            // Single cleanup path: release the script runner first, then publish the final
            // status, then ask the node server to stop.
            stopExecution(this.currentResultStatus == ExecutionStatus.SUCCEED);
            this.resultStatus = this.currentResultStatus;
            this.server.setAmCommand(NodeServer.AMCommand.STOP);
        }
    }

    /**
     * Selects {@link ProcessPythonRunnerV2} as script runner class in the context properties,
     * obtains the runner from {@link ScriptRunnerFactory} and runs the script.
     *
     * @throws Exception if creating the runner or running the script fails
     */
    public void runScript() throws Exception {
        this.mlContext.getProperties()
            .put("script_runner_class", ProcessPythonRunnerV2.class.getCanonicalName());
        this.scriptRunner = ScriptRunnerFactory.getScriptRunner(this.mlContext);
        this.scriptRunner.runScript();
    }

    /**
     * Stores the JSON form of {@link #mlClusterDef} under the {@code cluster} property and
     * records the local IP and node server port on the context.
     */
    public void resetMLContext() {
        final String clusterJson = ProtoUtil.protoToJson(this.mlClusterDef);
        LOG.info("java cluster:{}", clusterJson);
        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.println("java cluster:" + clusterJson);
        }
        this.mlContext.getProperties().put("cluster", clusterJson);
        this.mlContext.setNodeServerIP(this.localIp);
        this.mlContext.setNodeServerPort(this.server.getPort());
    }

    /** No-op in this runner. */
    public void startHeartBeat() throws Exception {
    }

    /** No-op in this runner. */
    public void getCurrentJobVersion() {
    }

    /** No-op in this runner. */
    public void initAMClient() throws Exception {
    }

    /** No-op in this runner. */
    public void waitClusterRunning() throws InterruptedException, MLException {
    }

    /**
     * Closes and forgets the active script runner, if any, and resets the context unless the
     * run succeeded.
     *
     * @param success {@code true} when the run succeeded, in which case the context is kept
     */
    protected void stopExecution(boolean success) {
        if (this.scriptRunner != null) {
            IOUtils.closeQuietly(this.scriptRunner);
            this.scriptRunner = null;
        }
        if (!success) {
            this.mlContext.reset();
        }
    }

    /**
     * Returns the published execution status.
     *
     * @return the current value of {@link #resultStatus}
     */
    public ExecutionStatus getResultStatus() {
        return this.resultStatus;
    }

    /**
     * Requests the run to stop: forwards a kill signal to the active script runner, if any, and
     * marks the result as {@link ExecutionStatus#KILLED_BY_FLINK} so that the next
     * {@link #checkEnd()} aborts the run.
     */
    public void notifyStop() {
        // Snapshot the field: stopExecution() sets it to null, so re-reading it after the null
        // check could fail if both methods overlap.
        final ScriptRunner activeRunner = this.scriptRunner;
        if (activeRunner != null) {
            activeRunner.notifyKillSignal();
        }
        this.resultStatus = ExecutionStatus.KILLED_BY_FLINK;
    }
}
