Skip to content

Commit 0569cf9

Browse files
committed
Implement the ability to infer the checkpoint type
Currently this is just based on the file path. In order to make the sniffing extensible, it is implemented as a plugin. In order to do that, and to avoid having to import a deep learning framework to do sniffing, the sniffers are separated out into their own files. This led to the question of why the checkpoint plugins are stored in the main repo. If we move them to the actual plugins area we can continue to support them while not needing to install them if a framework isn't going to be used. Now when you want to use git theta with a specific framework you can either run `pip install git-theta[framework]` or run `pip install git-theta-checkpoints-framework` for your framework.
1 parent d3c288d commit 0569cf9

File tree

22 files changed

+386
-18
lines changed

22 files changed

+386
-18
lines changed

git_theta/checkpoints/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
Checkpoint,
55
get_checkpoint_handler,
66
get_checkpoint_handler_name,
7+
sniff_checkpoint,
78
)

git_theta/checkpoints/base.py

+18
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,21 @@ def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint:
157157
checkpoint_type = get_checkpoint_handler_name(checkpoint_type)
158158
discovered_plugins = entry_points(group="git_theta.plugins.checkpoints")
159159
return discovered_plugins[checkpoint_type].load()
160+
161+
162+
def sniff_checkpoint(checkpoint_path) -> str:
163+
"""En"""
164+
discovered_plugins = entry_points(group="git_theta.plugins.checkpoint.sniffers")
165+
loaded_plugins = {ep.name: ep.load() for ep in discovered_plugins}
166+
logger = logging.getLogger("git_theta")
167+
logger.debug(
168+
f"Sniffing {checkpoint_path} to infer which deep learning framework it is."
169+
)
170+
for ckpt_type, ckpt_sniffer in loaded_plugins.items():
171+
logger.debug(f"Checking if {checkpoint_path} is a {ckpt_type} checkpoint.")
172+
if ckpt_sniffer(checkpoint_path):
173+
logger.debug(
174+
f"Determined that {checkpoint_path} is a {ckpt_type} checkpoint."
175+
)
176+
return ckpt_type
177+
raise ValueError(f"Couldn't determine checkpoint type for {checkpoint_path}")

git_theta/scripts/git_theta_filter.py

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def run_clean(args):
3636
"""
3737
logger = logging.getLogger("git_theta")
3838
logger.debug(f"Running clean filter on {args.file}")
39+
if EnvVarConstants.CHECKPOINT_TYPE == "sniff":
    # The helper exported by git_theta.checkpoints is `sniff_checkpoint`, and
    # it needs the path of the file being filtered to infer the framework.
    EnvVarConstants.CHECKPOINT_TYPE = checkpoints.sniff_checkpoint(args.file)
3941
repo = git_utils.get_git_repo()
4042
checkpoint_handler = checkpoints.get_checkpoint_handler()
4143
if EnvVarConstants.LOW_MEMORY:
@@ -74,6 +76,8 @@ def run_smudge(args):
7476
"""
7577
logger = logging.getLogger("git_theta")
7678
logger.debug(f"Running smudge filter on {args.file}")
79+
if EnvVarConstants.CHECKPOINT_TYPE == "sniff":
    # The helper exported by git_theta.checkpoints is `sniff_checkpoint`, and
    # it needs the path of the file being filtered to infer the framework.
    EnvVarConstants.CHECKPOINT_TYPE = checkpoints.sniff_checkpoint(args.file)
7781

7882
repo = git_utils.get_git_repo()
7983
curr_metadata = metadata.Metadata.from_file(sys.stdin)

git_theta/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __get__(self, obj, objtype=None):
7171

7272

7373
class EnvVarConstants:
74-
CHECKPOINT_TYPE = EnvVar(name="GIT_THETA_CHECKPOINT_TYPE", default="pytorch")
74+
CHECKPOINT_TYPE = EnvVar(name="GIT_THETA_CHECKPOINT_TYPE", default="sniff")
7575
UPDATE_TYPE = EnvVar(name="GIT_THETA_UPDATE_TYPE", default="dense")
7676
UPDATE_DATA_PATH = EnvVar(name="GIT_THETA_UPDATE_DATA_PATH", default="")
7777
PARAMETER_ATOL = EnvVar(name="GIT_THETA_PARAMETER_ATOL", default=1e-8)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.2.0"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Infer if a checkpoint is flax based.
2+
3+
We put this in a different file to avoid importing dl frameworks for file sniffing.
4+
"""
5+
6+
7+
def flax_sniffer(checkpoint_path: str) -> bool:
    """Report whether checkpoint_path is named like a flax checkpoint.

    Only the name is inspected; the file contents are never read.
    """
    # TODO: Check if the actual value is msgpack based on magic numbers?
    flax_suffix = ".flax"
    return checkpoint_path.endswith(flax_suffix)

plugins/checkpoints/flax/setup.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Plugin to support the flax checkpoint format."""
2+
3+
import ast
4+
import os
5+
6+
from setuptools import setup
7+
8+
9+
def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment of a string to version_variable in file_name.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If one of the
    # assigned names matches the given parameter, grab the value (which will
    # be the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        # Targets can be attributes, subscripts or tuples, none of which have
        # an `.id`; only plain names can be the version variable.
        assigns_version = any(
            isinstance(target, ast.Name) and target.id == version_variable
            for target in node.targets
        )
        value = node.value
        if (
            assigns_version
            and isinstance(value, ast.Constant)
            and isinstance(value.value, str)
        ):
            return value.value
    raise ValueError(
        f"Could not find an assignment to {version_variable} " f"within '{file_name}'"
    )
41+
42+
43+
# Checkpoint handler entry points; the same class is registered under two names.
_CHECKPOINT_ENTRY_POINTS = [
    "flax = git_theta_checkpoints_flax.checkpoints:FlaxCheckpoint",
    "flax-checkpoint = git_theta_checkpoints_flax.checkpoints:FlaxCheckpoint",
]
# Sniffer entry points, kept in a module that does not import the framework.
_SNIFFER_ENTRY_POINTS = [
    "flax = git_theta_checkpoints_flax.sniffer:flax_sniffer",
]

setup(
    name="git_theta_checkpoints_flax",
    version=get_version("git_theta_checkpoints_flax/__init__.py"),
    description="Plugin to support the flax checkpoint format.",
    author="Brian Lester",
    packages=["git_theta_checkpoints_flax"],
    install_requires=[
        # "git_theta",
        "flax",
        "jax",
    ],
    entry_points={
        "git_theta.plugins.checkpoints": _CHECKPOINT_ENTRY_POINTS,
        "git_theta.plugins.checkpoint.sniffers": _SNIFFER_ENTRY_POINTS,
    },
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.2.0"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Infer if a checkpoint is pytorch based.
2+
3+
We put this in a different file to avoid importing dl frameworks for file sniffing.
4+
"""
5+
6+
import re
7+
8+
9+
def pytorch_sniffer(checkpoint_path: str) -> bool:
    """Guess from its name whether checkpoint_path is a pytorch checkpoint.

    Only the path string is inspected; the file contents are never read.

    Parameters
    ----------
    checkpoint_path : str
        The path of the checkpoint, with or without leading directories.

    Returns
    -------
    is_pytorch : bool
        True if the file is named ``pytorch_model.bin`` or ends in ``.pt``,
        ``.pth`` or ``.pyt``, False otherwise.
    """
    # Many checkpoints on HuggingFace Hub are named this. Match on the final
    # path component so the file is also recognized inside sub-directories.
    if re.search(r"(?:^|[\\/])pytorch_model\.bin$", checkpoint_path):
        return True
    # `.pt` and `.pth` are the conventional pytorch extensions; `.pyt` is kept
    # because it was accepted before.
    if re.search(r"\.(?:py?t|pth)$", checkpoint_path):
        return True
    return False

plugins/checkpoints/pytorch/setup.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Plugin to support the pytorch checkpoint format."""
2+
3+
import ast
4+
import os
5+
6+
from setuptools import setup
7+
8+
9+
def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment of a string to version_variable in file_name.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If one of the
    # assigned names matches the given parameter, grab the value (which will
    # be the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        # Targets can be attributes, subscripts or tuples, none of which have
        # an `.id`; only plain names can be the version variable.
        assigns_version = any(
            isinstance(target, ast.Name) and target.id == version_variable
            for target in node.targets
        )
        value = node.value
        if (
            assigns_version
            and isinstance(value, ast.Constant)
            and isinstance(value.value, str)
        ):
            return value.value
    raise ValueError(
        f"Could not find an assignment to {version_variable} " f"within '{file_name}'"
    )
41+
42+
43+
# Checkpoint handler entry points; the same class is registered under two names.
_CHECKPOINT_ENTRY_POINTS = [
    "pytorch = git_theta_checkpoints_pytorch.checkpoints:PickledDictCheckpoint",
    "pickled-dict = git_theta_checkpoints_pytorch.checkpoints:PickledDictCheckpoint",
]
# Sniffer entry points, kept in a module that does not import the framework.
_SNIFFER_ENTRY_POINTS = [
    "pytorch = git_theta_checkpoints_pytorch.sniffer:pytorch_sniffer",
]

setup(
    name="git_theta_checkpoints_pytorch",
    version=get_version("git_theta_checkpoints_pytorch/__init__.py"),
    description="Plugin to support the pytorch checkpoint format.",
    author="Brian Lester",
    packages=["git_theta_checkpoints_pytorch"],
    install_requires=[
        # "git_theta",
        "torch",
    ],
    entry_points={
        "git_theta.plugins.checkpoints": _CHECKPOINT_ENTRY_POINTS,
        "git_theta.plugins.checkpoint.sniffers": _SNIFFER_ENTRY_POINTS,
    },
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.2.0"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Checkpoint using the HF safetensors format.
2+
3+
safetensors has the ability to write model checkpoint from "dl-native" -> "safetensors"
4+
and read "safetensors" -> any "dl-native" framework, not just the one that was
5+
used to write it. Therefore, we read/write with their numpy API.
6+
"""
7+
8+
import safetensors.numpy
9+
from file_or_name import file_or_name
10+
11+
from git_theta.checkpoints import Checkpoint
12+
13+
14+
# TODO(bdlester): Can we leverage the lazy loading ability to make things faster?
15+
class SafeTensorsCheckpoint(Checkpoint):
    """Class for r/w of the safetensors format. https://github.com/huggingface/safetensors"""

    # Identifier for this checkpoint format.
    name: str = "safetensors"

    @classmethod
    @file_or_name(checkpoint_path="rb")
    def load(cls, checkpoint_path: str):
        """Read all of checkpoint_path and deserialize it with safetensors.numpy.load."""
        # NOTE(review): `.read()` is called on checkpoint_path, so the
        # file_or_name decorator presumably hands us a file object opened in
        # "rb" mode even when a path string was given — confirm against the
        # file_or_name documentation.
        # Note that we use the numpy as the framework because we don't care what
        # their downstream dl framework is, we only want the results back as
        # numpy arrays.
        return safetensors.numpy.load(checkpoint_path.read())

    @file_or_name(checkpoint_path="wb")
    def save(self, checkpoint_path: str):
        """Serialize this checkpoint with safetensors.numpy.save and write it to checkpoint_path."""
        # Note, git theta uses numpy internally, so we save using the numpy api,
        # regardless of the original framework they used to write the checkpoint.
        checkpoint_dict = self.to_framework()
        checkpoint_path.write(safetensors.numpy.save(checkpoint_dict))

    def to_framework(self):
        """Return the checkpoint itself; no conversion is performed."""
        return self

    @classmethod
    def from_framework(cls, model_dict):
        """Build a checkpoint by passing model_dict straight to the constructor."""
        return cls(model_dict)
41+
42+
43+
def safetensors_sniffer(checkpoint_path: str) -> bool:
    """Report whether checkpoint_path is named like a safetensors checkpoint.

    Only the name is inspected; the file contents are never read.
    """
    safetensors_suffix = ".safetensors"
    return checkpoint_path.endswith(safetensors_suffix)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Infer if a checkpoint is safetensors based.
2+
3+
We put this in a different file to avoid importing dl frameworks for file sniffing.
4+
"""
5+
6+
7+
def safetensors_sniffer(checkpoint_path: str) -> bool:
    """Report whether checkpoint_path is named like a safetensors checkpoint.

    Only the name is inspected; the file contents are never read.
    """
    safetensors_suffix = ".safetensors"
    return checkpoint_path.endswith(safetensors_suffix)
+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Plugin to support the safetensors format."""
2+
3+
import ast
4+
import os
5+
6+
from setuptools import setup
7+
8+
9+
def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment of a string to version_variable in file_name.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If one of the
    # assigned names matches the given parameter, grab the value (which will
    # be the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        # Targets can be attributes, subscripts or tuples, none of which have
        # an `.id`; only plain names can be the version variable.
        assigns_version = any(
            isinstance(target, ast.Name) and target.id == version_variable
            for target in node.targets
        )
        value = node.value
        if (
            assigns_version
            and isinstance(value, ast.Constant)
            and isinstance(value.value, str)
        ):
            return value.value
    raise ValueError(
        f"Could not find an assignment to {version_variable} " f"within '{file_name}'"
    )
41+
42+
43+
# Checkpoint handler entry points; the same class is registered under two names.
_CHECKPOINT_ENTRY_POINTS = [
    "safetensors = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint",
    "safetensors-checkpoint = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint",
]
# Sniffer entry points, kept in a module that does not import the framework.
_SNIFFER_ENTRY_POINTS = [
    "safetensors = git_theta_checkpoints_safetensors.sniffer:safetensors_sniffer",
]

setup(
    name="git_theta_checkpoints_safetensors",
    version=get_version("git_theta_checkpoints_safetensors/__init__.py"),
    description="Plugin to support the safetensors checkpoint format.",
    author="Brian Lester",
    packages=["git_theta_checkpoints_safetensors"],
    install_requires=[
        # "git_theta",
        "safetensors",
    ],
    entry_points={
        "git_theta.plugins.checkpoints": _CHECKPOINT_ENTRY_POINTS,
        "git_theta.plugins.checkpoint.sniffers": _SNIFFER_ENTRY_POINTS,
    },
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.2.0"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Infer if a checkpoint is tensorflow based.
2+
3+
We put this in a different file to avoid importing dl frameworks for file sniffing.
4+
"""
5+
6+
7+
def tensorflow_sniffer(checkpoint_path: str) -> bool:
    """Report whether checkpoint_path is named like a tensorflow checkpoint.

    Only the name is inspected; the file contents are never read.
    """
    tensorflow_suffix = ".tf"
    return checkpoint_path.endswith(tensorflow_suffix)
9+
10+
11+
# TODO: Add support for detecting saved models.
12+
def saved_model_sniffer(checkpoint_path: str) -> bool:
    """Never claim checkpoint_path, because saved models are not supported yet."""
    # We don't support saved models yet.
    is_saved_model = False
    return is_saved_model

0 commit comments

Comments
 (0)