-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement the ability to infer the checkpoint type #254
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,5 @@ | |
Checkpoint, | ||
get_checkpoint_handler, | ||
get_checkpoint_handler_name, | ||
sniff_checkpoint, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -157,3 +157,21 @@ def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint: | |
checkpoint_type = get_checkpoint_handler_name(checkpoint_type) | ||
discovered_plugins = entry_points(group="git_theta.plugins.checkpoints") | ||
return discovered_plugins[checkpoint_type].load() | ||
|
||
|
||
def sniff_checkpoint(checkpoint_path) -> str: | ||
"""En""" | ||
discovered_plugins = entry_points(group="git_theta.plugins.checkpoint.sniffers") | ||
loaded_plugins = {ep.name: ep.load() for ep in discovered_plugins} | ||
logger = logging.getLogger("git_theta") | ||
logger.debug( | ||
f"Sniffing {checkpoint_path} to infer which deep learning framework it is." | ||
) | ||
for ckpt_type, ckpt_sniffer in loaded_plugins.items(): | ||
logger.debug(f"Checking if {checkpoint_path} is a {ckpt_type} checkpoint.") | ||
if ckpt_sniffer(checkpoint_path): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we have any kind of error handling in the case where two checkpoint sniffers denote that a given checkpoint is their type? Currently this will return the checkpoint type corresponding to the first checkpoint sniffer the loop encounters that matches (which I assume is in arbitrary order). |
||
logger.debug( | ||
f"Determined that {checkpoint_path} is a {ckpt_type} checkpoint." | ||
) | ||
return ckpt_type | ||
raise ValueError(f"Couldn't determine checkpoint type for {checkpoint_path}") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Version string of this plugin package. NOTE(review): presumably the value that setup.py's get_version() parses — confirm.
__version__ = "0.2.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Infer if a checkpoint is flax based. | ||
|
||
We put this in a different file to avoid importing dl frameworks for file sniffing. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a relevant comment for this particular sniffer? |
||
""" | ||
|
||
|
||
def flax_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` ends with the ``.flax`` extension.

    Only the name is inspected; the file contents are never read.
    """
    # TODO: Check if the actual value is msgpack based on magic numbers?
    has_flax_suffix = checkpoint_path.endswith(".flax")
    return has_flax_suffix
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Plugin to support the flax checkpoint format.""" | ||
|
||
import ast | ||
import os | ||
|
||
from setuptools import setup | ||
|
||
|
||
# NOTE(review): this helper is not flax specific and is duplicated across the
# plugin setup.py files; consider moving it into a shared utility.
def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment to version_variable in file_name, or the
        assigned value is not a literal.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If the variable
    # name matches the given parameter, grab the value (which will be
    # the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Targets can be attributes, subscripts or tuples, none of which
            # have an ``id``; only plain names can be the version variable.
            if isinstance(target, ast.Name) and target.id == version_variable:
                # literal_eval reads the constant without relying on the
                # deprecated ``.s`` alias on AST nodes.
                return ast.literal_eval(node.value)
    raise ValueError(
        f"Could not find an assignment to {version_variable} within '{file_name}'"
    )
|
||
|
||
# Package metadata for the flax plugin. The version is parsed out of the
# package's __init__.py so it is only written down in one place.
setup(
    name="git_theta_checkpoints_flax",
    version=get_version("git_theta_checkpoints_flax/__init__.py"),
    description="Plugin to support the flax checkpoint format.",
    author="Brian Lester",
    packages=["git_theta_checkpoints_flax"],
    install_requires=[
        # "git_theta",
        "flax",
        "jax",
    ],
    entry_points={
        # Checkpoint handler, registered under two names.
        "git_theta.plugins.checkpoints": [
            "flax = git_theta_checkpoints_flax.checkpoints:FlaxCheckpoint",
            "flax-checkpoint = git_theta_checkpoints_flax.checkpoints:FlaxCheckpoint",
        ],
        # Sniffer used to infer the checkpoint type from a path.
        "git_theta.plugins.checkpoint.sniffers": [
            "flax = git_theta_checkpoints_flax.sniffer:flax_sniffer",
        ],
    },
)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Version string of this plugin package. NOTE(review): presumably the value that setup.py's get_version() parses — confirm.
__version__ = "0.2.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
"""Infer if a checkpoint is pytorch based. | ||
|
||
We put this in a different file to avoid importing dl frameworks for file sniffing. | ||
""" | ||
|
||
import re | ||
|
||
|
||
def pytorch_sniffer(checkpoint_path: str) -> bool:
    """Guess from its name whether ``checkpoint_path`` is a PyTorch checkpoint.

    Only the path is inspected; the file contents are never read.

    Parameters
    ----------
    checkpoint_path : str
        The path (or bare file name) of the checkpoint.

    Returns
    -------
    is_pytorch : bool
        True if the file is named ``pytorch_model.bin`` or ends with one of
        the ``.pt``, ``.pth`` or ``.pyt`` extensions.
    """
    # Function-scope import keeps this edit self-contained.
    import os

    # Many checkpoints on HuggingFace Hub are named this. Compare only the
    # final path component so that "some/dir/pytorch_model.bin" is recognized
    # as well as the bare file name.
    if os.path.basename(checkpoint_path) == "pytorch_model.bin":
        return True
    # ".pt" and ".pth" are the conventional PyTorch extensions; ".pyt" is kept
    # because earlier versions of this sniffer accepted it.
    return re.search(r"\.(?:pt|pth|pyt)$", checkpoint_path) is not None
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Plugin to support the pytorch checkpoint format.""" | ||
|
||
import ast | ||
import os | ||
|
||
from setuptools import setup | ||
|
||
|
||
def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment to version_variable in file_name, or the
        assigned value is not a literal.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If the variable
    # name matches the given parameter, grab the value (which will be
    # the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Targets can be attributes, subscripts or tuples, none of which
            # have an ``id``; only plain names can be the version variable.
            if isinstance(target, ast.Name) and target.id == version_variable:
                # literal_eval reads the constant without relying on the
                # deprecated ``.s`` alias on AST nodes.
                return ast.literal_eval(node.value)
    raise ValueError(
        f"Could not find an assignment to {version_variable} within '{file_name}'"
    )
|
||
|
||
# Package metadata for the pytorch plugin. The version is parsed out of the
# package's __init__.py so it is only written down in one place.
setup(
    name="git_theta_checkpoints_pytorch",
    version=get_version("git_theta_checkpoints_pytorch/__init__.py"),
    description="Plugin to support the pytorch checkpoint format.",
    author="Brian Lester",
    packages=["git_theta_checkpoints_pytorch"],
    install_requires=[
        # "git_theta",
        "torch",
    ],
    entry_points={
        # Checkpoint handler, registered under two names.
        "git_theta.plugins.checkpoints": [
            "pytorch = git_theta_checkpoints_pytorch.checkpoints:PickledDictCheckpoint",
            "pickled-dict = git_theta_checkpoints_pytorch.checkpoints:PickledDictCheckpoint",
        ],
        # Sniffer used to infer the checkpoint type from a path.
        "git_theta.plugins.checkpoint.sniffers": [
            "pytorch = git_theta_checkpoints_pytorch.sniffer:pytorch_sniffer",
        ],
    },
)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Version string of this plugin package. NOTE(review): presumably the value that setup.py's get_version() parses — confirm.
__version__ = "0.2.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Checkpoint using the HF safetensors format. | ||
|
||
safetensors has the ability to write model checkpoint from "dl-native" -> "safetensors" | ||
and read "safetensors" -> any "dl-native" framework, not just the one that was | ||
used to write it. Therefore, we read/write with their numpy API. | ||
""" | ||
|
||
import safetensors.numpy | ||
from file_or_name import file_or_name | ||
|
||
from git_theta.checkpoints import Checkpoint | ||
|
||
|
||
# TODO(bdlester): Can we leverage the lazying loading ability to make things faster? | ||
class SafeTensorsCheckpoint(Checkpoint):
    """Checkpoint handler for the safetensors format.

    See https://github.com/huggingface/safetensors
    """

    name: str = "safetensors"

    @classmethod
    @file_or_name(checkpoint_path="rb")
    def load(cls, checkpoint_path: str):
        """Read a safetensors checkpoint through the numpy API."""
        # The body calls .read() on checkpoint_path, so the decorator
        # presumably hands over an open binary file — confirm against the
        # file_or_name documentation.
        raw_bytes = checkpoint_path.read()
        # The numpy API is used because the downstream dl framework does not
        # matter here; the results are only wanted back as numpy arrays.
        return safetensors.numpy.load(raw_bytes)

    @file_or_name(checkpoint_path="wb")
    def save(self, checkpoint_path: str):
        """Serialize this checkpoint with the safetensors numpy API and write it out."""
        # git theta uses numpy internally, so the numpy API is used for saving
        # regardless of the framework that originally wrote the checkpoint.
        serialized = safetensors.numpy.save(self.to_framework())
        checkpoint_path.write(serialized)

    def to_framework(self):
        """Return this checkpoint itself; no conversion is applied."""
        return self

    @classmethod
    def from_framework(cls, model_dict):
        """Wrap ``model_dict`` in an instance of this checkpoint class."""
        return cls(model_dict)
|
||
|
||
# NOTE(review): an identical sniffer lives in the package's sniffer module,
# which is the one setup.py registers as an entry point. A reviewer asked why
# it is also defined here; this copy may be redundant — confirm before removing.
def safetensors_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` ends with the ``.safetensors`` extension."""
    has_safetensors_suffix = checkpoint_path.endswith(".safetensors")
    return has_safetensors_suffix
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Infer if a checkpoint is safetensors based. | ||
|
||
We put this in a different file to avoid importing dl frameworks for file sniffing. | ||
""" | ||
|
||
|
||
def safetensors_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` ends with the ``.safetensors`` extension.

    Only the name is inspected; the file contents are never read.
    """
    has_safetensors_suffix = checkpoint_path.endswith(".safetensors")
    return has_safetensors_suffix
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Plugin to support the safetensors format.""" | ||
|
||
import ast | ||
import os | ||
|
||
from setuptools import setup | ||
|
||
|
||
def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment to version_variable in file_name, or the
        assigned value is not a literal.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If the variable
    # name matches the given parameter, grab the value (which will be
    # the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Targets can be attributes, subscripts or tuples, none of which
            # have an ``id``; only plain names can be the version variable.
            if isinstance(target, ast.Name) and target.id == version_variable:
                # literal_eval reads the constant without relying on the
                # deprecated ``.s`` alias on AST nodes.
                return ast.literal_eval(node.value)
    raise ValueError(
        f"Could not find an assignment to {version_variable} within '{file_name}'"
    )
|
||
|
||
# Package metadata for the safetensors plugin. The version is parsed out of
# the package's __init__.py so it is only written down in one place.
setup(
    name="git_theta_checkpoints_safetensors",
    version=get_version("git_theta_checkpoints_safetensors/__init__.py"),
    description="Plugin to support the safetensors checkpoint format.",
    author="Brian Lester",
    packages=["git_theta_checkpoints_safetensors"],
    install_requires=[
        # "git_theta",
        "safetensors",
    ],
    entry_points={
        # Checkpoint handler, registered under two names.
        "git_theta.plugins.checkpoints": [
            "safetensors = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint",
            "safetensors-checkpoint = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint",
        ],
        # Sniffer used to infer the checkpoint type from a path.
        "git_theta.plugins.checkpoint.sniffers": [
            "safetensors = git_theta_checkpoints_safetensors.sniffer:safetensors_sniffer",
        ],
    },
)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Version string of this plugin package. NOTE(review): presumably the value that setup.py's get_version() parses — confirm.
__version__ = "0.2.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Infer if a checkpoint is tensorflow based. | ||
|
||
We put this in a different file to avoid importing dl frameworks for file sniffing. | ||
""" | ||
|
||
|
||
def tensorflow_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` ends with the ``.tf`` extension.

    Only the name is inspected; the file contents are never read.
    """
    has_tf_suffix = checkpoint_path.endswith(".tf")
    return has_tf_suffix
|
||
|
||
# TODO: Add support for detecting saved models. | ||
def saved_model_sniffer(checkpoint_path: str) -> bool:
    """Placeholder sniffer for saved models; it never reports a match.

    ``checkpoint_path`` is accepted to keep the sniffer signature but is not
    inspected.
    """
    # We don't support saved models yet.
    is_saved_model = False
    return is_saved_model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's En?