diff --git a/nvflare/fuel/f3/streaming/stream_utils.py b/nvflare/fuel/f3/streaming/stream_utils.py index 6be7d2097f..8357d44eaf 100644 --- a/nvflare/fuel/f3/streaming/stream_utils.py +++ b/nvflare/fuel/f3/streaming/stream_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import threading import time @@ -22,11 +23,33 @@ STREAM_THREAD_POOL_SIZE = 128 ONE_MB = 1024 * 1024 -stream_thread_pool = ThreadPoolExecutor(STREAM_THREAD_POOL_SIZE, "stm") lock = threading.Lock() sid_base = int((time.time() + os.getpid()) * 1000000) # microseconds stream_count = 0 +log = logging.getLogger(__name__) + + +class CheckedExecutor(ThreadPoolExecutor): + """This executor ignores task after shutting down""" + + def __init__(self, max_workers=None, thread_name_prefix=""): + super().__init__(max_workers, thread_name_prefix) + self.stopped = False + + def shutdown(self, wait=True): + super().shutdown(wait) + self.stopped = True + + def submit(self, fn, *args, **kwargs): + if self.stopped: + log.debug(f"Call {fn} is ignored after streaming shutting down") + else: + super().submit(fn, *args, **kwargs) + + +stream_thread_pool = CheckedExecutor(STREAM_THREAD_POOL_SIZE, "stm") + def wrap_view(buffer: BytesAlike) -> memoryview: if isinstance(buffer, memoryview):