Git Theta Low Memory Mode #234

Merged · 1 commit · Apr 24, 2024
1 change: 1 addition & 0 deletions git_theta/__init__.py
@@ -7,6 +7,7 @@
lsh,
metadata,
params,
scripts,
theta,
updates,
utils,
9 changes: 7 additions & 2 deletions git_theta/checkpoints/pickled_dict_checkpoint.py
@@ -7,6 +7,7 @@
from git_theta.checkpoints import Checkpoint


# TODO: We should rename this back to something Torch-specific, since we do things like check whether values are torch.Tensors.
class PickledDictCheckpoint(Checkpoint):
"""Class for wrapping picked dict checkpoints, commonly used with PyTorch."""

@@ -29,7 +30,9 @@ def load(cls, checkpoint_path):
if isinstance(checkpoint_path, io.IOBase):
checkpoint_path = io.BytesIO(checkpoint_path.read())

model_dict = torch.load(checkpoint_path)
# Map all values to the CPU as they may have been saved on the GPU and we don't
# know if the same GPU topology is available now.
model_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
if not isinstance(model_dict, dict):
raise ValueError("Supplied PyTorch checkpoint must be a dict.")
if not all(isinstance(k, str) for k in model_dict.keys()):
@@ -40,7 +43,9 @@ def load(cls, checkpoint_path):

@classmethod
def from_framework(cls, model_dict):
return cls({k: v.cpu().numpy() for k, v in model_dict.items()})
# If tensors were saved with requires_grad set, we need to detach them
# before converting them to numpy arrays.
return cls({k: v.cpu().detach().numpy() for k, v in model_dict.items()})

def to_framework(self):
return {k: torch.as_tensor(v) for k, v in self.items()}
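The two changes above guard against separate failure modes: a checkpoint saved on a GPU cannot be loaded on a CPU-only machine unless map_location redirects the tensors, and a tensor loaded with requires_grad=True refuses .numpy() until it is detached. A minimal sketch of both behaviors (illustrative only, not part of this diff; the key name is made up):

import io

import torch

# Save a dict containing a Parameter (requires_grad=True), as a training script might.
buf = io.BytesIO()
torch.save({"layer.weight": torch.nn.Parameter(torch.randn(4, 4))}, buf)
buf.seek(0)

# map_location keeps the load from trying to place tensors on a GPU that may
# not exist on this machine.
state = torch.load(buf, map_location=torch.device("cpu"))

# Calling .numpy() directly would raise "Can't call numpy() on Tensor that
# requires grad"; detach first, then convert.
arrays = {k: v.cpu().detach().numpy() for k, v in state.items()}
print({k: a.shape for k, a in arrays.items()})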
21 changes: 20 additions & 1 deletion git_theta/filters.py
@@ -29,9 +29,9 @@ def clean(
update_serializer, EnvVarConstants.UPDATE_DATA_PATH
)
prev_metadata = metadata.Metadata.from_commit(repo, path, "HEAD").flatten()
logger = logging.getLogger("git_theta")

async def _clean(param_keys, new_param):
logger = logging.getLogger("git_theta")
logger.debug(f"Cleaning {'/'.join(param_keys)}")
# Get the metadata from the previous version of the parameter
param_metadata = prev_metadata.get(param_keys)
@@ -104,10 +104,29 @@ async def _clean(param_keys, new_param):
theta_metadata=new_theta_metadata,
)
logger.debug(f"Finished Cleaning {'/'.join(param_keys)}")
del new_param
return param_keys, new_param_metadata

# Sort the keys so we don't get changing diffs based on serialization order.
sorted_checkpoint = dict(sorted(checkpoint.flatten().items()))
if EnvVarConstants.LOW_MEMORY:
# Run one parameter at a time and delete the old values as you go.
# TODO: Is it possible/better to process the keys based on the size
# of the tensor and re-sort later? Then you could do things like delete
# all the small ones before you have to process the large one.
logger.warning(
"Runing Git-Theta in Low Memory Mode, no concurrency will be used, and references to parameter weights will be freed after use."
)
meta = {}
for k in list(sorted_checkpoint.keys()):
# Get the param while removing it from the dict; dropping the
# dict's reference will allow the tensor to be gc'd.
v = sorted_checkpoint.pop(k)
param_name, param_meta = async_utils.run(_clean(k, v))
meta[param_name] = param_meta
# Drop the reference to the value to allow it to be gc'd.
del v
return metadata.Metadata(meta).unflatten()
return metadata.Metadata(
**async_utils.run(
async_utils.run_map(
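The low-memory branch above trades concurrency for a bounded working set: each parameter is popped out of the flattened checkpoint before it is cleaned, so the dict no longer holds a reference and the tensor can be garbage collected as soon as _clean is done with it. A standalone sketch of that pattern (function and argument names are illustrative, not the library's API):

def clean_low_memory(sorted_checkpoint, clean_one):
    """Clean one parameter at a time, releasing each tensor as soon as it is processed."""
    meta = {}
    for key in list(sorted_checkpoint.keys()):
        value = sorted_checkpoint.pop(key)  # the dict no longer references the tensor
        name, param_meta = clean_one(key, value)
        meta[name] = param_meta
        del value  # drop the last local reference so the tensor can be gc'd
    return meta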
7 changes: 2 additions & 5 deletions git_theta/git_utils.py
@@ -351,11 +351,8 @@ def get_file_version(repo, path: str, commit_hash_or_tag: Union[str, git.Commit]
# GitPython can take commit sha1's or tags (or commit objects) here and
# it gives the same results.
tree = repo.commit(commit_hash_or_tag).tree
if path in tree:
return tree[path]
else:
return None
except git.BadName:
return tree[path]
except (git.BadName, KeyError):
return None


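The simplification above relies on GitPython's own error signaling: repo.commit raises git.BadName for an unknown commit, and indexing a commit tree with a path that does not exist raises KeyError, so one try/except covers both cases. A small illustration (the repository path and file name are placeholders):

import git

repo = git.Repo(".")
tree = repo.commit("HEAD").tree
try:
    blob = tree["model.pt"]  # a Blob (or Tree) object if the path exists at HEAD
except KeyError:
    blob = None  # the path is not tracked at this commit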
File renamed without changes.
24 changes: 22 additions & 2 deletions git_theta/scripts/git_theta_filter.py
@@ -4,7 +4,6 @@
import logging
import os
import sys
import tempfile

import git_theta
from git_theta import checkpoints, git_utils, metadata
@@ -39,7 +38,28 @@ def run_clean(args):
logger.debug(f"Running clean filter on {args.file}")
repo = git_utils.get_git_repo()
checkpoint_handler = checkpoints.get_checkpoint_handler()
model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer)
if EnvVarConstants.LOW_MEMORY:
logger.warning(
"Running Git-Theta in low memory mode. No concurrency is supported and the original checkout value will be transiently stored in a temporary file."
)
temp_file = f".{args.file}-temp-checkpoint"
try:
# In some environments there isn't enough space at the default
# tempfile location, so we write next to the target file instead.
logger.debug(f"Writing checkpoint to {temp_file}")
with open(temp_file, "w+b") as tmp:
tmp.write(sys.stdin.buffer.read())
logger.debug(f"Reading checkpoint from {temp_file}")
# We write and then seek instead of write, close, open because this was
# originally written to use the tempfile lib, but there were space
# issues. We keep that paradigm as we may switch back eventually.
tmp.seek(0)
model_checkpoint = checkpoint_handler.from_file(tmp)
finally:
# Make sure we always remove the temp checkpoint file.
os.remove(temp_file)
else:
model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer)
new_metadata = clean(model_checkpoint, repo, args.file)
new_metadata.write(sys.stdout)
# If we had side-loaded information, write it out so we don't get false
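In low-memory mode the clean filter spills stdin to a file that sits next to the checkpoint being added rather than in the system temp directory, rewinds the same handle instead of closing and reopening it, and always deletes the file afterwards. A reduced sketch of that flow (the file name is a placeholder for the one derived from args.file, and the read stands in for checkpoint_handler.from_file):

import os
import sys

temp_file = ".model.pt-temp-checkpoint"
try:
    with open(temp_file, "w+b") as tmp:
        tmp.write(sys.stdin.buffer.read())  # spill the streamed checkpoint to disk
        tmp.seek(0)  # rewind and reuse the same handle
        data = tmp.read()  # stand-in for checkpoint_handler.from_file(tmp)
finally:
    os.remove(temp_file)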
1 change: 1 addition & 0 deletions git_theta/utils.py
@@ -82,6 +82,7 @@ class EnvVarConstants:
MAX_CONCURRENCY = EnvVar(name="GIT_THETA_MAX_CONCURRENCY", default=-1)
MANUAL_MERGE = EnvVar(name="GIT_THETA_MANUAL_MERGE", default=False)
LOG_LEVEL = EnvVar(name="GIT_THETA_LOG_LEVEL", default="DEBUG")
LOW_MEMORY = EnvVar(name="GIT_THETA_LOW_MEMORY", default=False)


def flatten(
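With the new constant, low-memory mode is opt-in through the GIT_THETA_LOW_MEMORY environment variable and defaults to off. A hedged sketch of the kind of lookup this implies (this is not the library's actual EnvVar class, just an illustration of reading a boolean flag from the environment):

import os

def low_memory_enabled(default: bool = False) -> bool:
    """Treat an unset or falsey GIT_THETA_LOW_MEMORY as off."""
    raw = os.environ.get("GIT_THETA_LOW_MEMORY")
    if raw is None:
        return default
    return raw.strip().lower() not in ("", "0", "false", "no")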
2 changes: 1 addition & 1 deletion setup.py
@@ -108,7 +108,7 @@ def get_version(file_name: str, version_variable: str = "__version__") -> str:
},
entry_points={
"console_scripts": [
"git-theta = git_theta.scripts.git_theta:main",
"git-theta = git_theta.scripts.git_theta_cli:main",
"git-theta-filter = git_theta.scripts.git_theta_filter:main",
"git-theta-merge = git_theta.scripts.git_theta_merge:main",
"git-theta-diff = git_theta.scripts.git_theta_diff:main",