From 539e221ebbd798e8de89349302f9913ea9d6704b Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Mon, 9 Sep 2024 14:09:49 +0800 Subject: [PATCH 1/4] Fix the connection between server and tracker when using WSL2 --- python/tvm/exec/rpc_server.py | 225 +++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 102 deletions(-) diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 4da88bcdebfc..922bbd1c9954 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -1,102 +1,123 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=redefined-outer-name, invalid-name -"""Start an RPC server""" -import argparse -import logging -from .. import rpc - - -def main(args): - """Main function - - Parameters - ---------- - args : argparse.Namespace - parsed args from command-line invocation - """ - if args.tracker: - url, port = args.tracker.rsplit(":", 1) - port = int(port) - tracker_addr = (url, port) - if not args.key: - raise RuntimeError("Need key to present type of resource when tracker is available") - else: - tracker_addr = None - - server = rpc.Server( - args.host, - args.port, - args.port_end, - is_proxy=args.through_proxy, - key=args.key, - tracker_addr=tracker_addr, - load_library=args.load_library, - custom_addr=args.custom_addr, - silent=args.silent, - no_fork=not args.fork, - ) - server.proc.join() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="The host IP address the tracker binds to" - ) - parser.add_argument("--port", type=int, default=9090, help="The port of the RPC") - parser.add_argument( - "--through-proxy", - dest="through_proxy", - action="store_true", - help=( - "Whether this server provide service through a proxy. If this is true, the host and" - "port actually is the address of the proxy." - ), - ) - parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC") - parser.add_argument( - "--tracker", - type=str, - help=("The address of RPC tracker in host:port format. " "e.g. (10.77.1.234:9190)"), - ) - parser.add_argument( - "--key", type=str, default="", help="The key used to identify the device type in tracker." - ) - parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.") - parser.add_argument("--load-library", type=str, help="Additional library to load") - parser.add_argument( - "--no-fork", - dest="fork", - action="store_false", - help="Use spawn mode to avoid fork. This option \ - is able to avoid potential fork problems with Metal, OpenCL \ - and ROCM compilers.", - ) - parser.add_argument( - "--custom-addr", type=str, help="Custom IP Address to Report to RPC Tracker" - ) - - parser.set_defaults(fork=True) - args = parser.parse_args() - logging.basicConfig(level=logging.INFO) - if not args.fork is False and not args.silent: - logging.info( - "If you are running ROCM/Metal, fork will cause " - "compiler internal error. Try to launch with arg ```--no-fork```" - ) - main(args) +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-outer-name, invalid-name +"""Start an RPC server""" +import argparse +import logging +from .. import rpc +import socket + + +def get_local_ip(): + try: + # create UDP socket + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # connect to a outer server, but we don't send data + s.connect(("8.8.8.8", 80)) + # get local ip + local_ip = s.getsockname()[0] + s.close() + return local_ip + except Exception: + return None + + +def main(args): + """Main function + + Parameters + ---------- + args : argparse.Namespace + parsed args from command-line invocation + """ + if args.tracker: + url, port = args.tracker.rsplit(":", 1) + port = int(port) + tracker_addr = (url, port) + if not args.key: + raise RuntimeError("Need key to present type of resource when tracker is available") + else: + tracker_addr = None + external_ip = get_local_ip() + + # + if external_ip and not args.custom_addr: + custom_addr = f"{external_ip}" + else: + custom_addr = args.custom_addr + server = rpc.Server( + args.host, + args.port, + args.port_end, + is_proxy=args.through_proxy, + key=args.key, + tracker_addr=tracker_addr, + load_library=args.load_library, + custom_addr=custom_addr, + silent=args.silent, + no_fork=not args.fork, + ) + server.proc.join() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="The host IP address the tracker binds to" + ) + parser.add_argument("--port", type=int, default=9090, help="The port of the RPC") + parser.add_argument( + "--through-proxy", + dest="through_proxy", + action="store_true", + help=( + "Whether this server provide service through a proxy. If this is true, the host and" + "port actually is the address of the proxy." + ), + ) + parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC") + parser.add_argument( + "--tracker", + type=str, + help=("The address of RPC tracker in host:port format. " "e.g. (10.77.1.234:9190)"), + ) + parser.add_argument( + "--key", type=str, default="", help="The key used to identify the device type in tracker." + ) + parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.") + parser.add_argument("--load-library", type=str, help="Additional library to load") + parser.add_argument( + "--no-fork", + dest="fork", + action="store_false", + help="Use spawn mode to avoid fork. This option \ + is able to avoid potential fork problems with Metal, OpenCL \ + and ROCM compilers.", + ) + parser.add_argument( + "--custom-addr", type=str, help="Custom IP Address to Report to RPC Tracker" + ) + + parser.set_defaults(fork=True) + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + if not args.fork is False and not args.silent: + logging.info( + "If you are running ROCM/Metal, fork will cause " + "compiler internal error. Try to launch with arg ```--no-fork```" + ) + main(args) From 9e3b6ffdd1dbf985e06b744696a4a2ede669ff80 Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Mon, 9 Sep 2024 14:48:59 +0800 Subject: [PATCH 2/4] [Fix] the connection between server and tracker when using WSL2 --- python/tvm/exec/rpc_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 922bbd1c9954..cb29b7446ffb 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -53,8 +53,8 @@ def main(args): else: tracker_addr = None external_ip = get_local_ip() - - # + + # if external_ip and not args.custom_addr: custom_addr = f"{external_ip}" else: From f125df30cb33e6458709bc86c06f0e4c2f06a59d Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Mon, 9 Sep 2024 15:08:13 +0800 Subject: [PATCH 3/4] [Fix] the connection between server and tracker when using WSL2 --- python/tvm/exec/rpc_server.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index cb29b7446ffb..1ac25aa1f4e9 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -18,11 +18,20 @@ """Start an RPC server""" import argparse import logging -from .. import rpc import socket +from .. import rpc + def get_local_ip(): + """ + Attempt to get the local IP address of the machine. + + Returns: + -------- + str or None + The IP address of the machine as a string if successful; None if failed. + """ try: # create UDP socket s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -32,7 +41,7 @@ def get_local_ip(): local_ip = s.getsockname()[0] s.close() return local_ip - except Exception: + except (socket.error, OSError): return None From 4108456e145ef590a5192f11914bc4a85aa16441 Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Mon, 9 Sep 2024 16:00:07 +0800 Subject: [PATCH 4/4] [Fix] the connection between server and tracker when using WSL2 --- python/tvm/exec/rpc_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 1ac25aa1f4e9..d563f3fecf67 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -22,7 +22,6 @@ from .. import rpc - def get_local_ip(): """ Attempt to get the local IP address of the machine.