diff --git a/tests/__init__.py b/tests/__init__.py
index 3633a4ab70..0e9503fb7d 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -4,20 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
-import sys
-import warnings
-
-# Avoid memory fragmentation and peak reserved memory increasing over time
-# To overwrite, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False
-if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
-    if "torch" in sys.modules:
-        warnings.warn(
-            "The 'torch' module has already been imported. "
-            "Setting PYTORCH_CUDA_ALLOC_CONF may not have an effect."
-            "For best results, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing 'torch'."
-        )
-    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 # Check at the top-level that torchao is installed.
 # This is better than doing it at every import site.
diff --git a/torchtune/__init__.py b/torchtune/__init__.py
index d744705292..aef6275ce6 100644
--- a/torchtune/__init__.py
+++ b/torchtune/__init__.py
@@ -8,19 +8,19 @@
 import os
-import sys
 import warnings
 
-# Avoid memory fragmentation and peak reserved memory increasing over time
-# To overwrite, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False
-if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
-    if "torch" in sys.modules:
-        warnings.warn(
-            "The 'torch' module has already been imported. "
-            "Setting PYTORCH_CUDA_ALLOC_CONF may not have an effect."
-            "For best results, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing 'torch'."
-        )
-    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+
+if torch.cuda.is_available():
+    ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
+
+    if ca_config is None or "expandable_segments:False" not in ca_config:
+        try:
+            # Avoid memory fragmentation and peak reserved memory increasing over time.
+            torch.cuda.memory._set_allocator_settings("expandable_segments:True")
+        except RuntimeError:
+            warnings.warn("Setting expandable_segments:True for CUDA allocator failed.")
 
 # Check at the top-level that torchao is installed.