Cleanup Bogus Errors (#3190)
Removed several bogus errors (minimal sketches of items 1 and 2 follow
this list):
1. `Error processing frame: RuntimeError: cannot schedule new futures
after shutdown`
This can happen when a message arrives for a cell that is already being
shut down. Changed it to a debug log.
2. `Logical Error: Endpoint is already removed`
This is actually not an error: a weak-ref set entry is removed
automatically once the referent is gone. Changed it to a debug log.
3. Added names to several threads to help debug the dangling-thread
issue.
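
For reference, two minimal sketches (hypothetical stand-ins, not the NVFlare code; names like `Endpoint` below are illustrative only).

Item 1: a plain `ThreadPoolExecutor` raises `RuntimeError` when a task is submitted after `shutdown()`, which is exactly what happens when a frame arrives for a cell that is already shutting down:

```python
# Minimal repro sketch: submitting after shutdown() raises RuntimeError.
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=2)
pool.shutdown(wait=True)
try:
    pool.submit(print, "late frame")
except RuntimeError as ex:
    print(ex)  # cannot schedule new futures after shutdown
```

The new `CheckedExecutor` in `stream_utils.py` (see the diff below) tracks a `stopped` flag and turns this case into a debug log instead of an exception.

Item 2: entries in a `weakref.WeakSet` disappear on their own once the last strong reference to them dies, so a later `remove()` raising `KeyError` is expected behavior, not a logical error:

```python
# Minimal sketch: WeakSet entries vanish when the referent is collected,
# so remove() can legitimately raise KeyError afterwards.
import gc
import weakref

class Endpoint:  # hypothetical stand-in for the real class
    pass

live = weakref.WeakSet()
ep = Endpoint()
live.add(ep)
print(len(live))  # 1

ep = None     # drop the last strong reference
gc.collect()  # immediate on CPython; explicit for other runtimes
print(len(live))  # 0 -- the entry removed itself

ep2 = Endpoint()
live.add(ep2)
live.remove(ep2)
try:
    live.remove(ep2)  # second remove: entry is already gone
except KeyError:
    print("already removed -- expected, hence debug instead of error")
```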


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.

---------

Co-authored-by: Ziyue Xu <[email protected]>
nvidianz and ZiyueXu77 authored Jan 31, 2025
1 parent b26b224 commit f8dd354
Showing 8 changed files with 42 additions and 12 deletions.
2 changes: 1 addition & 1 deletion nvflare/apis/impl/wf_comm_server.py
@@ -93,7 +93,7 @@ def __init__(self, task_check_period=0.2):
         self._client_task_map = {}  # client_task_id => client_task
         self._all_done = False
         self._task_lock = Lock()
-        self._task_monitor = threading.Thread(target=self._monitor_tasks, args=(), daemon=True)
+        self._task_monitor = threading.Thread(target=self._monitor_tasks, args=(), name="wf_task", daemon=True)
         self._task_check_period = task_check_period
         self._dead_client_grace = 60.0
         self._dead_clients = {}  # clients reported dead: name => _DeadClientStatus
4 changes: 2 additions & 2 deletions nvflare/apis/utils/reliable_message.py
@@ -387,7 +387,7 @@ def enable(cls, fl_ctx: FLContext):
         )
 
         cls._query_interval = query_interval
-        cls._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_request_workers)
+        cls._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_request_workers, thread_name_prefix="rm")
         engine = fl_ctx.get_engine()
         engine.register_aux_message_handler(
             topic=TOPIC_RELIABLE_REQUEST,
@@ -397,7 +397,7 @@ def enable(cls, fl_ctx: FLContext):
             topic=TOPIC_RELIABLE_REPLY,
             message_handle_func=cls._receive_reply,
         )
-        t = threading.Thread(target=cls._monitor_req_receivers, daemon=True)
+        t = threading.Thread(target=cls._monitor_req_receivers, name="rm_monitor", daemon=True)
         t.start()
         cls._logger.info(f"enabled reliable message: {max_request_workers=} {query_interval=}")
 
6 changes: 4 additions & 2 deletions nvflare/fuel/f3/cellnet/core_cell.py
@@ -1439,7 +1439,7 @@ def queue_message(self, channel: str, topic: str, targets: Union[str, List[str]]
         with self.bulk_lock:
             if self.bulk_checker is None:
                 self.logger.info(f"{self.my_info.fqcn}: starting bulk_checker")
-                self.bulk_checker = threading.Thread(target=self._check_bulk, name="check_bulk_msg")
+                self.bulk_checker = threading.Thread(target=self._check_bulk, name="check_bulk_msg", daemon=True)
                 self.bulk_checker.start()
                 self.logger.info(f"{self.my_info.fqcn}: started bulk_checker")
             for t in targets:
@@ -1474,7 +1474,9 @@ def _receive_bulk_message(self, request: Message):
         with self.bulk_msg_lock:
             if self.bulk_processor is None:
                 self.logger.debug(f"{self.my_info.fqcn}: starting bulk message processor")
-                self.bulk_processor = threading.Thread(target=self._process_bulk_messages, name="process_bulk_msg")
+                self.bulk_processor = threading.Thread(
+                    target=self._process_bulk_messages, name="process_bulk_msg", daemon=True
+                )
                 self.bulk_processor.start()
                 self.logger.debug(f"{self.my_info.fqcn}: started bulk message processor")
             self.bulk_messages.append(request)
6 changes: 2 additions & 4 deletions nvflare/fuel/f3/communicator.py
@@ -27,7 +27,6 @@
 from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor
 from nvflare.fuel.f3.message import Message, MessageReceiver
 from nvflare.fuel.f3.sfm.conn_manager import ConnManager, Mode
-from nvflare.security.logging import secure_format_exception
 
 log = logging.getLogger(__name__)
 _running_instances = weakref.WeakSet()
@@ -86,9 +85,8 @@ def stop(self):
         try:
             _running_instances.remove(self)
         except KeyError as ex:
-            log.error(
-                f"Logical error, communicator {self.local_endpoint.name} is not started: {secure_format_exception(ex)}"
-            )
+            # For weak-ref set, the entry may be removed automatically if no other ref so this is not an error
+            log.debug(f"Weak-ref for Communicator {self.local_endpoint.name} is already removed")
 
         log.debug(f"Communicator endpoint: {self.local_endpoint.name} has stopped")
 
2 changes: 1 addition & 1 deletion nvflare/fuel/f3/drivers/aio_grpc_driver.py
@@ -74,7 +74,7 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di
         if conf.get_bool_var("simulate_unstable_network", default=False):
             if context:
                 # only server side
-                self.disconn = threading.Thread(target=self._disconnect, daemon=True)
+                self.disconn = threading.Thread(target=self._disconnect, name="grpc_disc", daemon=True)
                 self.disconn.start()
 
     def _disconnect(self):
2 changes: 1 addition & 1 deletion nvflare/fuel/f3/drivers/grpc_driver.py
@@ -145,7 +145,7 @@ def Stream(self, request_iterator, context):
             self.logger.debug(f"SERVER created connection in thread {ct.name}")
             self.server.driver.add_connection(connection)
             self.logger.debug(f"SERVER created read_loop thread in thread {ct.name}")
-            t = threading.Thread(target=connection.read_loop, args=(request_iterator,), daemon=True)
+            t = threading.Thread(target=connection.read_loop, args=(request_iterator,), name="grpc_reader", daemon=True)
             t.start()
             yield from connection.generate_output()
         except Exception as ex:
7 changes: 7 additions & 0 deletions nvflare/fuel/f3/sfm/conn_manager.py
@@ -82,6 +82,7 @@ def __init__(self, local_endpoint: Endpoint):
         self.receivers: Dict[int, MessageReceiver] = {}
 
         self.started = False
+        self.stopped = False
         self.conn_mgr_executor = ThreadPoolExecutor(CONN_THREAD_POOL_SIZE, "conn_mgr")
         self.frame_mgr_executor = ThreadPoolExecutor(FRAME_THREAD_POOL_SIZE, "frame_mgr")
         self.lock = threading.Lock()
@@ -154,6 +155,8 @@ def stop(self):
         self.conn_mgr_executor.shutdown(True)
         self.frame_mgr_executor.shutdown(True)
 
+        self.stopped = True
+
     def find_endpoint(self, name: str) -> Optional[Endpoint]:
 
         sfm_endpoint = self.sfm_endpoints.get(name)
@@ -369,6 +372,10 @@ def process_frame_task(self, sfm_conn: SfmConnection, frame: BytesAlike):
             log.debug(secure_format_traceback())
 
     def process_frame(self, sfm_conn: SfmConnection, frame: BytesAlike):
+        if self.stopped:
+            log.debug(f"Frame received after shutdown for connection {sfm_conn.get_name()}")
+            return
+
         self.frame_mgr_executor.submit(self.process_frame_task, sfm_conn, frame)
 
     def update_endpoint(self, sfm_conn: SfmConnection, data: dict):
25 changes: 24 additions & 1 deletion nvflare/fuel/f3/streaming/stream_utils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 import threading
 import time
@@ -22,11 +23,33 @@
 STREAM_THREAD_POOL_SIZE = 128
 ONE_MB = 1024 * 1024
 
-stream_thread_pool = ThreadPoolExecutor(STREAM_THREAD_POOL_SIZE, "stm")
 lock = threading.Lock()
 sid_base = int((time.time() + os.getpid()) * 1000000)  # microseconds
 stream_count = 0
 
+log = logging.getLogger(__name__)
+
+
+class CheckedExecutor(ThreadPoolExecutor):
+    """This executor ignores tasks after shutting down"""
+
+    def __init__(self, max_workers=None, thread_name_prefix=""):
+        super().__init__(max_workers, thread_name_prefix)
+        self.stopped = False
+
+    def shutdown(self, wait=True):
+        self.stopped = True
+        super().shutdown(wait)
+
+    def submit(self, fn, *args, **kwargs):
+        if self.stopped:
+            log.debug(f"Call {fn} is ignored after streaming shutting down")
+        else:
+            super().submit(fn, *args, **kwargs)
+
+
+stream_thread_pool = CheckedExecutor(STREAM_THREAD_POOL_SIZE, "stm")
+
+
 def wrap_view(buffer: BytesAlike) -> memoryview:
     if isinstance(buffer, memoryview):
