diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index 21fd0c6033..7a9814d114 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -93,7 +93,7 @@ def __init__(self, task_check_period=0.2): self._client_task_map = {} # client_task_id => client_task self._all_done = False self._task_lock = Lock() - self._task_monitor = threading.Thread(target=self._monitor_tasks, args=(), daemon=True) + self._task_monitor = threading.Thread(target=self._monitor_tasks, args=(), name="wf_task", daemon=True) self._task_check_period = task_check_period self._dead_client_grace = 60.0 self._dead_clients = {} # clients reported dead: name => _DeadClientStatus diff --git a/nvflare/apis/utils/reliable_message.py b/nvflare/apis/utils/reliable_message.py index 8d3923bf04..c10908c92e 100644 --- a/nvflare/apis/utils/reliable_message.py +++ b/nvflare/apis/utils/reliable_message.py @@ -387,7 +387,7 @@ def enable(cls, fl_ctx: FLContext): ) cls._query_interval = query_interval - cls._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_request_workers) + cls._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_request_workers, thread_name_prefix="rm") engine = fl_ctx.get_engine() engine.register_aux_message_handler( topic=TOPIC_RELIABLE_REQUEST, @@ -397,7 +397,7 @@ def enable(cls, fl_ctx: FLContext): topic=TOPIC_RELIABLE_REPLY, message_handle_func=cls._receive_reply, ) - t = threading.Thread(target=cls._monitor_req_receivers, daemon=True) + t = threading.Thread(target=cls._monitor_req_receivers, name="rm_monitor", daemon=True) t.start() cls._logger.info(f"enabled reliable message: {max_request_workers=} {query_interval=}") diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 4e8b68fa8b..7228e7a8c8 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -1439,7 +1439,7 @@ def queue_message(self, channel: str, topic: 
str, targets: Union[str, List[str]] with self.bulk_lock: if self.bulk_checker is None: self.logger.info(f"{self.my_info.fqcn}: starting bulk_checker") - self.bulk_checker = threading.Thread(target=self._check_bulk, name="check_bulk_msg") + self.bulk_checker = threading.Thread(target=self._check_bulk, name="check_bulk_msg", daemon=True) self.bulk_checker.start() self.logger.info(f"{self.my_info.fqcn}: started bulk_checker") for t in targets: @@ -1474,7 +1474,9 @@ def _receive_bulk_message(self, request: Message): with self.bulk_msg_lock: if self.bulk_processor is None: self.logger.debug(f"{self.my_info.fqcn}: starting bulk message processor") - self.bulk_processor = threading.Thread(target=self._process_bulk_messages, name="process_bulk_msg") + self.bulk_processor = threading.Thread( + target=self._process_bulk_messages, name="process_bulk_msg", daemon=True + ) self.bulk_processor.start() self.logger.debug(f"{self.my_info.fqcn}: started bulk message processor") self.bulk_messages.append(request) diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index 02714fd84e..dbd4e86298 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -27,7 +27,6 @@ from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor from nvflare.fuel.f3.message import Message, MessageReceiver from nvflare.fuel.f3.sfm.conn_manager import ConnManager, Mode -from nvflare.security.logging import secure_format_exception log = logging.getLogger(__name__) _running_instances = weakref.WeakSet() @@ -86,9 +85,8 @@ def stop(self): try: _running_instances.remove(self) except KeyError as ex: - log.error( - f"Logical error, communicator {self.local_endpoint.name} is not started: {secure_format_exception(ex)}" - ) + # For weak-ref set, the entry may be removed automatically if no other ref so this is not an error + log.debug(f"Weak-ref for Communicator {self.local_endpoint.name} is already removed") log.debug(f"Communicator endpoint: 
{self.local_endpoint.name} has stopped") diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 972e8fc05f..aa71b0441a 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -74,7 +74,7 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di if conf.get_bool_var("simulate_unstable_network", default=False): if context: # only server side - self.disconn = threading.Thread(target=self._disconnect, daemon=True) + self.disconn = threading.Thread(target=self._disconnect, name="grpc_disc", daemon=True) self.disconn.start() def _disconnect(self): diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index 2cd2d0bb1b..7cdd444360 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -145,7 +145,7 @@ def Stream(self, request_iterator, context): self.logger.debug(f"SERVER created connection in thread {ct.name}") self.server.driver.add_connection(connection) self.logger.debug(f"SERVER created read_loop thread in thread {ct.name}") - t = threading.Thread(target=connection.read_loop, args=(request_iterator,), daemon=True) + t = threading.Thread(target=connection.read_loop, args=(request_iterator,), name="grpc_reader", daemon=True) t.start() yield from connection.generate_output() except Exception as ex: diff --git a/nvflare/fuel/f3/sfm/conn_manager.py b/nvflare/fuel/f3/sfm/conn_manager.py index 28a345ff20..4baa58502b 100644 --- a/nvflare/fuel/f3/sfm/conn_manager.py +++ b/nvflare/fuel/f3/sfm/conn_manager.py @@ -82,6 +82,7 @@ def __init__(self, local_endpoint: Endpoint): self.receivers: Dict[int, MessageReceiver] = {} self.started = False + self.stopped = False self.conn_mgr_executor = ThreadPoolExecutor(CONN_THREAD_POOL_SIZE, "conn_mgr") self.frame_mgr_executor = ThreadPoolExecutor(FRAME_THREAD_POOL_SIZE, "frame_mgr") self.lock = threading.Lock() @@ -154,6 +155,8 @@ 
def stop(self): self.conn_mgr_executor.shutdown(True) self.frame_mgr_executor.shutdown(True) + self.stopped = True + def find_endpoint(self, name: str) -> Optional[Endpoint]: sfm_endpoint = self.sfm_endpoints.get(name) @@ -369,6 +372,10 @@ def process_frame_task(self, sfm_conn: SfmConnection, frame: BytesAlike): log.debug(secure_format_traceback()) def process_frame(self, sfm_conn: SfmConnection, frame: BytesAlike): + if self.stopped: + log.debug(f"Frame received after shutdown for connection {sfm_conn.get_name()}") + return + self.frame_mgr_executor.submit(self.process_frame_task, sfm_conn, frame) def update_endpoint(self, sfm_conn: SfmConnection, data: dict): diff --git a/nvflare/fuel/f3/streaming/stream_utils.py b/nvflare/fuel/f3/streaming/stream_utils.py index 6be7d2097f..2c5ae53f31 100644 --- a/nvflare/fuel/f3/streaming/stream_utils.py +++ b/nvflare/fuel/f3/streaming/stream_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import logging
import os
import threading
import time

STREAM_THREAD_POOL_SIZE = 128
ONE_MB = 1024 * 1024

lock = threading.Lock()
sid_base = int((time.time() + os.getpid()) * 1000000)  # microseconds
stream_count = 0

log = logging.getLogger(__name__)


class CheckedExecutor(ThreadPoolExecutor):
    """Thread pool that drops tasks submitted after it has been shut down.

    Once ``shutdown()`` has been called, ``submit()`` logs the call at debug
    level and ignores it instead of forwarding it to the underlying pool.
    """

    def __init__(self, max_workers=None, thread_name_prefix=""):
        """Create the pool.

        Args:
            max_workers: forwarded to ``ThreadPoolExecutor``.
            thread_name_prefix: forwarded to ``ThreadPoolExecutor``.
        """
        super().__init__(max_workers, thread_name_prefix)
        # Set by shutdown(); checked by submit() to drop late tasks.
        self.stopped = False

    def shutdown(self, wait=True, **kwargs):
        """Mark the pool as stopped, then shut down the underlying executor.

        Args:
            wait: forwarded to ``ThreadPoolExecutor.shutdown``.
            **kwargs: any additional keyword arguments accepted by the base
                ``shutdown`` (e.g. ``cancel_futures``), forwarded unchanged so
                this override does not narrow the base-class signature.
        """
        # Flip the flag first so submissions arriving while the base class
        # is shutting down are already being ignored.
        self.stopped = True
        super().shutdown(wait, **kwargs)

    def submit(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` unless the pool has been stopped.

        Returns:
            The ``Future`` produced by the underlying executor, or ``None``
            when the call was ignored because the pool is already stopped.
        """
        if self.stopped:
            log.debug(f"Call {fn} is ignored after streaming shutting down")
            return None

        # Hand the Future back to the caller; inherited helpers such as
        # Executor.map() rely on submit() returning it.
        return super().submit(fn, *args, **kwargs)


stream_thread_pool = CheckedExecutor(STREAM_THREAD_POOL_SIZE, "stm")