
Commit d067dd8

Git Theta Low Memory Mode (#234)
The original file, piped into the filter from git, is stored in a temporary file before being read by the checkpoint plugin. When cleaning parameter groups, we free the memory for each group after it has been written to disk and converted to metadata. This is about as much as we can do until the DL frameworks' native checkpoint formats support streaming.
1 parent 019e124 commit d067dd8

8 files changed: +54, -11 lines

git_theta/__init__.py (+1)

@@ -7,6 +7,7 @@
     lsh,
     metadata,
     params,
+    scripts,
     theta,
     updates,
     utils,

git_theta/checkpoints/pickled_dict_checkpoint.py (+7 -2)

@@ -7,6 +7,7 @@
 from git_theta.checkpoints import Checkpoint


+# TODO: We should rename this back to being Torch-related, as we do things like check if values are torch.Tensors.
 class PickledDictCheckpoint(Checkpoint):
     """Class for wrapping pickled dict checkpoints, commonly used with PyTorch."""

@@ -29,7 +30,9 @@ def load(cls, checkpoint_path):
         if isinstance(checkpoint_path, io.IOBase):
             checkpoint_path = io.BytesIO(checkpoint_path.read())

-        model_dict = torch.load(checkpoint_path)
+        # Map all values to the CPU, as they may have been saved on the GPU and
+        # we don't know if the same GPU topology is available now.
+        model_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
         if not isinstance(model_dict, dict):
             raise ValueError("Supplied PyTorch checkpoint must be a dict.")
         if not all(isinstance(k, str) for k in model_dict.keys()):

@@ -40,7 +43,9 @@ def load(cls, checkpoint_path):

     @classmethod
     def from_framework(cls, model_dict):
-        return cls({k: v.cpu().numpy() for k, v in model_dict.items()})
+        # If tensors were saved with gradient tracking we need to detach them
+        # before converting them to numpy arrays.
+        return cls({k: v.cpu().detach().numpy() for k, v in model_dict.items()})

     def to_framework(self):
         return {k: torch.as_tensor(v) for k, v in self.items()}
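Both fixes guard against real PyTorch failure modes. Here is a minimal standalone sketch (not part of the commit) demonstrating them: loading a GPU-saved checkpoint on a CPU-only machine fails without map_location, and a tensor stored with requires_grad=True refuses .numpy() until it is detached.

import io

import torch

buf = io.BytesIO()
# Save a tensor that still has gradient tracking enabled.
torch.save({"w": torch.ones(3, requires_grad=True)}, buf)
buf.seek(0)

# map_location keeps the load working even if "w" had been saved on a GPU
# that is absent on the current machine.
state = torch.load(buf, map_location=torch.device("cpu"))

# state["w"].numpy() would raise "Can't call numpy() on Tensor that
# requires grad"; detaching first yields a plain numpy array.
array = state["w"].cpu().detach().numpy()
print(array)  # [1. 1. 1.]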

git_theta/filters.py (+20 -1)

@@ -29,9 +29,9 @@ def clean(
         update_serializer, EnvVarConstants.UPDATE_DATA_PATH
     )
     prev_metadata = metadata.Metadata.from_commit(repo, path, "HEAD").flatten()
+    logger = logging.getLogger("git_theta")

     async def _clean(param_keys, new_param):
-        logger = logging.getLogger("git_theta")
         logger.debug(f"Cleaning {'/'.join(param_keys)}")
         # Get the metadata from the previous version of the parameter
         param_metadata = prev_metadata.get(param_keys)

@@ -104,10 +104,29 @@ async def _clean(param_keys, new_param):
             theta_metadata=new_theta_metadata,
         )
         logger.debug(f"Finished Cleaning {'/'.join(param_keys)}")
+        del new_param
         return param_keys, new_param_metadata

     # Sort the keys so we don't get changing diffs based on serialization order.
     sorted_checkpoint = dict(sorted(checkpoint.flatten().items()))
+    if EnvVarConstants.LOW_MEMORY:
+        # Run one parameter at a time and delete the old values as you go.
+        # TODO: Is it possible/better to process the keys based on the size
+        # of the tensor and re-sort later? Then you could do things like delete
+        # all the small ones before you have to process the large one.
+        logger.warning(
+            "Running Git-Theta in Low Memory Mode; no concurrency will be used, and references to parameter weights will be freed after use."
+        )
+        meta = {}
+        for k in list(sorted_checkpoint.keys()):
+            # Pop the param from the dict; removing the dict's reference
+            # will allow the tensor to be gc'd.
+            v = sorted_checkpoint.pop(k)
+            param_name, param_meta = async_utils.run(_clean(k, v))
+            meta[param_name] = param_meta
+            # Drop our reference to the value to allow it to be gc'd.
+            del v
+        return metadata.Metadata(meta).unflatten()
     return metadata.Metadata(
         **async_utils.run(
             async_utils.run_map(
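The low-memory branch trades concurrency for a bounded working set. A simplified sketch (hypothetical names, not from the commit) of the pop-and-free pattern it relies on; in CPython, dropping the last reference is enough to free each tensor before the next one is processed:

def process_sequentially(checkpoint: dict, process) -> dict:
    results = {}
    # Iterate over a snapshot of the keys since the dict is mutated below.
    for key in list(checkpoint.keys()):
        value = checkpoint.pop(key)  # the dict no longer holds a reference
        results[key] = process(key, value)
        del value  # drop the local reference so the value can be gc'd
    return results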

git_theta/git_utils.py (+2 -5)

@@ -351,11 +351,8 @@ def get_file_version(repo, path: str, commit_hash_or_tag: Union[str, git.Commit]
         # GitPython can take commit sha1's or tags (or commit objects) here and
         # it gives the same results.
         tree = repo.commit(commit_hash_or_tag).tree
-        if path in tree:
-            return tree[path]
-        else:
-            return None
-    except git.BadName:
+        return tree[path]
+    except (git.BadName, KeyError):
         return None

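The refactor folds the membership test into the lookup: GitPython's Tree raises KeyError for a missing path, just as Repo.commit raises git.BadName for an unknown revision, so a single except clause now covers both. A hedged sketch of the same EAFP pattern, assuming it runs inside some git checkout:

import git

repo = git.Repo(".")
try:
    # Tree.__getitem__ raises KeyError when the path is absent.
    blob = repo.commit("HEAD").tree["no/such/file"]
except (git.BadName, KeyError):
    blob = None
print(blob)  # None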

git_theta/scripts/{git_theta.py → git_theta_cli.py}

File renamed without changes.

git_theta/scripts/git_theta_filter.py (+22 -2)

@@ -4,7 +4,6 @@
 import logging
 import os
 import sys
-import tempfile

 import git_theta
 from git_theta import checkpoints, git_utils, metadata

@@ -39,7 +38,28 @@ def run_clean(args):
     logger.debug(f"Running clean filter on {args.file}")
     repo = git_utils.get_git_repo()
     checkpoint_handler = checkpoints.get_checkpoint_handler()
-    model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer)
+    if EnvVarConstants.LOW_MEMORY:
+        logger.warning(
+            "Running Git-Theta in low memory mode. No concurrency is supported and the original checkpoint will be transiently stored in a temporary file."
+        )
+        temp_file = f".{args.file}-temp-checkpoint"
+        try:
+            # In some places there isn't enough space when writing to the
+            # default tempfile location.
+            logger.debug(f"Writing checkpoint to {temp_file}")
+            with open(temp_file, "w+b") as tmp:
+                tmp.write(sys.stdin.buffer.read())
+                logger.debug(f"Reading checkpoint from {temp_file}")
+                # We write and then seek instead of write,close,open because this
+                # was originally written to use the tempfile lib, but there were
+                # space issues. We keep that paradigm as we may switch back eventually.
+                tmp.seek(0)
+                model_checkpoint = checkpoint_handler.from_file(tmp)
+        finally:
+            # Make sure we always remove the temp checkpoint file.
+            os.remove(temp_file)
+    else:
+        model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer)
     new_metadata = clean(model_checkpoint, repo, args.file)
     new_metadata.write(sys.stdout)
     # If we had side-loaded information, write it out so we don't get false
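For contrast, a sketch of the tempfile-based approach the comment says was abandoned: tempfile.TemporaryFile places data in the system temp directory (often a small tmpfs), which is the likely source of the space issues with multi-gigabyte checkpoints.

import sys
import tempfile

# The file lives in the system temp dir and is deleted automatically when the
# context manager exits, but that dir may be too small for a large checkpoint.
with tempfile.TemporaryFile() as tmp:
    tmp.write(sys.stdin.buffer.read())
    tmp.seek(0)  # the same write-then-seek paradigm the code above keeps
    data = tmp.read()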

git_theta/utils.py (+1)

@@ -82,6 +82,7 @@ class EnvVarConstants:
     MAX_CONCURRENCY = EnvVar(name="GIT_THETA_MAX_CONCURRENCY", default=-1)
     MANUAL_MERGE = EnvVar(name="GIT_THETA_MANUAL_MERGE", default=False)
     LOG_LEVEL = EnvVar(name="GIT_THETA_LOG_LEVEL", default="DEBUG")
+    LOW_MEMORY = EnvVar(name="GIT_THETA_LOW_MEMORY", default=False)


 def flatten(
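With the flag registered alongside the other settings, low-memory mode is presumably toggled through the environment like any other GIT_THETA_* variable (the exact truthy values depend on how EnvVar parses them):

# Assumed usage; from a shell, for a single command:
#   GIT_THETA_LOW_MEMORY=1 git add model.pt
# Or from Python, before the filter machinery runs:
import os

os.environ["GIT_THETA_LOW_MEMORY"] = "1"  # read back via EnvVarConstants.LOW_MEMORY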

setup.py (+1 -1)

@@ -108,7 +108,7 @@ def get_version(file_name: str, version_variable: str = "__version__") -> str:
     },
     entry_points={
         "console_scripts": [
-            "git-theta = git_theta.scripts.git_theta:main",
+            "git-theta = git_theta.scripts.git_theta_cli:main",
             "git-theta-filter = git_theta.scripts.git_theta_filter:main",
             "git-theta-merge = git_theta.scripts.git_theta_merge:main",
             "git-theta-diff = git_theta.scripts.git_theta_diff:main",
