
Commit 14b450d

Git Theta Low Memory Mode
1 parent 0e030d2 commit 14b450d

8 files changed: 53 additions & 10 deletions

git_theta/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
     lsh,
     metadata,
     params,
+    scripts,
     theta,
     updates,
     utils,

git_theta/checkpoints/pickled_dict_checkpoint.py

Lines changed: 7 additions & 2 deletions
@@ -7,6 +7,7 @@
 from git_theta.checkpoints import Checkpoint
 
 
+# TODO: We should rename this back to be Torch related as we do things like check if they are Torch.tensors.
 class PickledDictCheckpoint(Checkpoint):
     """Class for wrapping pickled dict checkpoints, commonly used with PyTorch."""
 
@@ -29,7 +30,9 @@ def load(cls, checkpoint_path):
         if isinstance(checkpoint_path, io.IOBase):
             checkpoint_path = io.BytesIO(checkpoint_path.read())
 
-        model_dict = torch.load(checkpoint_path)
+        # Map all values to the CPU as they may have been saved to the GPU and we don't
+        # know if the same GPU topology is available now.
+        model_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
         if not isinstance(model_dict, dict):
             raise ValueError("Supplied PyTorch checkpoint must be a dict.")
         if not all(isinstance(k, str) for k in model_dict.keys()):
@@ -40,7 +43,9 @@ def load(cls, checkpoint_path):
 
     @classmethod
     def from_framework(cls, model_dict):
-        return cls({k: v.cpu().numpy() for k, v in model_dict.items()})
+        # If things were saved with gradient requirements we need to detach them
+        # before converting them to numpy arrays.
+        return cls({k: v.cpu().detach().numpy() for k, v in model_dict.items()})
 
     def to_framework(self):
         return {k: torch.as_tensor(v) for k, v in self.items()}
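Both changes address how tensors come out of a pickled checkpoint. A minimal illustration of the two failure modes being avoided (PyTorch only; the file name in the commented-out load call is a placeholder):

import torch

# A tensor that still requires grad cannot be converted to numpy directly.
t = torch.ones(3, requires_grad=True)
try:
    t.cpu().numpy()
except RuntimeError as err:
    print(err)  # "Can't call numpy() on Tensor that requires grad..."
print(t.cpu().detach().numpy())  # detach first, then convert: [1. 1. 1.]

# map_location forces GPU-saved tensors onto the CPU, so a checkpoint written
# on a machine with GPUs still loads where there are none (or different ones).
# model_dict = torch.load("checkpoint.pt", map_location=torch.device("cpu"))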

git_theta/filters.py

Lines changed: 19 additions & 0 deletions
@@ -103,10 +103,29 @@ async def _clean(param_keys, new_param):
             theta_metadata=new_theta_metadata,
         )
         logging.debug(f"Finished Cleaning {'/'.join(param_keys)}")
+        del new_param
         return param_keys, new_param_metadata
 
     # Sort the keys so we don't get changing diffs based on serialization order.
     sorted_checkpoint = dict(sorted(checkpoint.flatten().items()))
+    if EnvVarConstants.LOW_MEMORY:
+        # Run one at a time and delete the old values as you go.
+        # TODO: Is it possible/better to process the keys based on the size
+        # of the tensor and re-sort later? Then you could do things like delete
+        # all the small ones before you have to process the large one.
+        logging.warning(
+            "Running Git-Theta in Low Memory Mode, no concurrency will be used, and references to parameter weights will be freed after use."
+        )
+        meta = {}
+        for k in list(sorted_checkpoint.keys()):
+            # Get the param while removing it from the dict, removing the
+            # reference there to allow the tensor to be gc'd.
+            v = sorted_checkpoint.pop(k)
+            param_name, param_meta = async_utils.run(_clean(k, v))
+            meta[param_name] = param_meta
+            # Drop the reference to the value to allow it to be gc'd.
+            del v
+        return metadata.Metadata(meta).unflatten()
     return metadata.Metadata(
         **async_utils.run(
             async_utils.run_map(
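The low-memory path works by making sure nothing else holds a reference to a parameter tensor once it has been cleaned, so Python's garbage collector can reclaim it before the next parameter is loaded. A standalone sketch of that pattern, with a stand-in `handle` callable rather than Git-Theta's `_clean` coroutine:

import gc

def process_one_at_a_time(checkpoint: dict, handle) -> dict:
    """Consume a {name: array} dict while keeping at most one value alive at a time."""
    results = {}
    for name in list(checkpoint.keys()):
        value = checkpoint.pop(name)   # drop the dict's reference to the tensor
        results[name] = handle(value)  # e.g. hash/serialize the parameter
        del value                      # drop the local reference so it can be gc'd
    gc.collect()
    return results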

git_theta/git_utils.py

Lines changed: 2 additions & 5 deletions
@@ -342,11 +342,8 @@ def get_file_version(repo, path: str, commit_hash_or_tag: Union[str, git.Commit]
         # GitPython can take commit sha1's or tags (or commit objects) here and
         # it gives the same results.
         tree = repo.commit(commit_hash_or_tag).tree
-        if path in tree:
-            return tree[path]
-        else:
-            return None
-    except git.BadName:
+        return tree[path]
+    except (git.BadName, KeyError):
         return None
 
 
File renamed without changes.

git_theta/scripts/git_theta_filter.py

Lines changed: 22 additions & 2 deletions
@@ -4,7 +4,6 @@
 import logging
 import os
 import sys
-import tempfile
 
 import git_theta
 from git_theta import checkpoints, git_utils, metadata
@@ -38,7 +37,28 @@ def run_clean(args):
     logging.debug(f"Running clean filter on {args.file}")
     repo = git_utils.get_git_repo()
     checkpoint_handler = checkpoints.get_checkpoint_handler()
-    model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer)
+    if EnvVarConstants.LOW_MEMORY:
+        logging.warning(
+            "Running Git-Theta in low memory mode. No concurrency is supported and the original checkpoint value will be transiently stored in a temporary file."
+        )
+        temp_file = f".{args.file}-temp-checkpoint"
+        try:
+            # In some places we don't have enough space when you write to the
+            # tempfile location.
+            logging.debug(f"Writing checkpoint to {temp_file}")
+            with open(temp_file, "w+b") as tmp:
+                tmp.write(sys.stdin.buffer.read())
+                logging.debug(f"Reading checkpoint from {temp_file}")
+                # We write and then seek instead of write, close, open because this was
+                # originally written to use the tempfile lib, but there were space
+                # issues. We keep that paradigm as we may switch back eventually.
+                tmp.seek(0)
+                model_checkpoint = checkpoint_handler.from_file(tmp)
+        finally:
+            # Make sure we always remove the temp checkpoint file.
+            os.remove(temp_file)
+    else:
+        model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer)
     new_metadata = clean(model_checkpoint, repo, args.file)
     new_metadata.write(sys.stdout)
     # If we had side-loaded information, write it out so we don't get false
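Because the clean filter receives the checkpoint bytes on stdin, low-memory mode spools them to a sidecar file next to the target path rather than buffering everything in RAM. A standalone sketch of that spool-and-cleanup pattern, with a stand-in `reader` callable rather than Git-Theta's checkpoint handler:

import os
import sys


def spool_stdin_to_sidecar(target_path: str, reader):
    """Write stdin to `.{target}-temp-checkpoint`, let `reader` consume the open
    handle, and always remove the sidecar file afterwards."""
    temp_file = f".{target_path}-temp-checkpoint"
    try:
        with open(temp_file, "w+b") as tmp:
            tmp.write(sys.stdin.buffer.read())
            tmp.seek(0)  # rewind so the reader sees the file from the start
            return reader(tmp)
    finally:
        os.remove(temp_file)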

git_theta/utils.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ class EnvVarConstants:
     MAX_CONCURRENCY = EnvVar(name="GIT_THETA_MAX_CONCURRENCY", default=-1)
     MANUAL_MERGE = EnvVar(name="GIT_THETA_MANUAL_MERGE", default=False)
     LOG_LEVEL = EnvVar(name="GIT_THETA_LOG_LEVEL", default="DEBUG")
+    LOW_MEMORY = EnvVar(name="GIT_THETA_LOW_MEMORY", default=False)
 
 
 def flatten(
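The new flag follows the existing EnvVar pattern: a named environment variable with a typed default, evaluated when the filters run. The sketch below is an illustrative stand-in for how such a descriptor could behave, not git_theta.utils's actual implementation; the truthy-string parsing in particular is an assumption.

import os
from dataclasses import dataclass
from typing import Any


@dataclass
class EnvVar:
    """Illustrative stand-in: resolve an environment variable against a typed default."""

    name: str
    default: Any

    def __get__(self, obj, objtype=None):
        raw = os.environ.get(self.name)
        if raw is None:
            return self.default
        if isinstance(self.default, bool):
            return raw.lower() in ("1", "true", "yes", "on")  # assumed parsing
        return type(self.default)(raw)


class EnvVarConstants:
    LOW_MEMORY = EnvVar(name="GIT_THETA_LOW_MEMORY", default=False)


# e.g. `GIT_THETA_LOW_MEMORY=true git add model.pt` would make the clean filter
# see EnvVarConstants.LOW_MEMORY as truthy and take the sequential code path.
print(EnvVarConstants.LOW_MEMORY)  # False unless the variable is set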

setup.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def get_version(file_name: str, version_variable: str = "__version__") -> str:
     },
     entry_points={
         "console_scripts": [
-            "git-theta = git_theta.scripts.git_theta:main",
+            "git-theta = git_theta.scripts.git_theta_cli:main",
             "git-theta-filter = git_theta.scripts.git_theta_filter:main",
             "git-theta-merge = git_theta.scripts.git_theta_merge:main",
             "git-theta-diff = git_theta.scripts.git_theta_diff:main",
