Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import warnings

# Avoid memory fragmentation and peak reserved memory increasing over time.
# To opt out, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False before launch.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    if "torch" in sys.modules:
        # Allocator settings are read when torch initializes CUDA; if torch is
        # already imported, setting the env var here may be too late.
        warnings.warn(
            "The 'torch' module has already been imported. "
            "Setting PYTORCH_CUDA_ALLOC_CONF may not have an effect. "
            "For best results, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing 'torch'."
        )
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Check at the top-level that torchao is installed.
# This is better than doing it at every import site.
Expand Down
22 changes: 11 additions & 11 deletions torchtune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@


import os
import sys
import warnings

# Avoid memory fragmentation and peak reserved memory increasing over time.
# To opt out, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False before launch.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    if "torch" in sys.modules:
        # Allocator settings are read when torch initializes CUDA; if torch is
        # already imported, setting the env var here may be too late.
        warnings.warn(
            "The 'torch' module has already been imported. "
            "Setting PYTORCH_CUDA_ALLOC_CONF may not have an effect. "
            "For best results, set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing 'torch'."
        )
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch

if torch.cuda.is_available():
    # Avoid memory fragmentation and peak reserved memory increasing over time.
    # Users opt out by putting "expandable_segments:False" in PYTORCH_CUDA_ALLOC_CONF.
    alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF")
    opted_out = alloc_conf is not None and "expandable_segments:False" in alloc_conf

    if not opted_out:
        try:
            # Applied at runtime via the private allocator-settings hook, so it
            # works even though torch has already been imported.
            torch.cuda.memory._set_allocator_settings("expandable_segments:True")
        except RuntimeError:
            warnings.warn("Setting expandable_segments:True for CUDA allocator failed.")


# Check at the top-level that torchao is installed.
Expand Down