From 7c9022e3f64b9ce661b9e26a3e4be79a60fb8f6f Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Mon, 14 Jul 2025 19:38:36 +0300 Subject: [PATCH] Set expandable_segments explicitly via cuda memory API --- tests/__init__.py | 14 -------------- torchtune/__init__.py | 22 +++++++++++----------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 3633a4ab70..0e9503fb7d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,20 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os -import sys -import warnings - -# Avoid memory fragmentation and peak reserved memory increasing over time -# To overwrite, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False -if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ: - if "torch" in sys.modules: - warnings.warn( - "The 'torch' module has already been imported. " - "Setting PYTORCH_CUDA_ALLOC_CONF may not have an effect." - "For best results, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing 'torch'." - ) - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Check at the top-level that torchao is installed. # This is better than doing it at every import site. diff --git a/torchtune/__init__.py b/torchtune/__init__.py index d744705292..aef6275ce6 100644 --- a/torchtune/__init__.py +++ b/torchtune/__init__.py @@ -8,19 +8,19 @@ import os -import sys import warnings -# Avoid memory fragmentation and peak reserved memory increasing over time -# To overwrite, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False -if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ: - if "torch" in sys.modules: - warnings.warn( - "The 'torch' module has already been imported. " - "Setting PYTORCH_CUDA_ALLOC_CONF may not have an effect." - "For best results, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing 'torch'." - ) - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +if torch.cuda.is_available(): + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None) + + if ca_config is None or "expandable_segments:False" not in ca_config: + try: + # Avoid memory fragmentation and peak reserved memory increasing over time. + torch.cuda.memory._set_allocator_settings("expandable_segments:True") + except RuntimeError: + warnings.warn("Setting expandable_segments:True for CUDA allocator failed.") # Check at the top-level that torchao is installed.