Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the ability to infer the checkpoint type #254

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions git_theta/checkpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
Checkpoint,
get_checkpoint_handler,
get_checkpoint_handler_name,
sniff_checkpoint,
)
18 changes: 18 additions & 0 deletions git_theta/checkpoints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,21 @@ def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint:
checkpoint_type = get_checkpoint_handler_name(checkpoint_type)
discovered_plugins = entry_points(group="git_theta.plugins.checkpoints")
return discovered_plugins[checkpoint_type].load()


def sniff_checkpoint(checkpoint_path) -> str:
"""En"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's En?

discovered_plugins = entry_points(group="git_theta.plugins.checkpoint.sniffers")
loaded_plugins = {ep.name: ep.load() for ep in discovered_plugins}
logger = logging.getLogger("git_theta")
logger.debug(
f"Sniffing {checkpoint_path} to infer which deep learning framework it is."
)
for ckpt_type, ckpt_sniffer in loaded_plugins.items():
logger.debug(f"Checking if {checkpoint_path} is a {ckpt_type} checkpoint.")
if ckpt_sniffer(checkpoint_path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have any kind of error handling in the case where two checkpoint sniffers denote that a given checkpoint is their type? Currently this will return the checkpoint type corresponding to the first checkpoint sniffer the loop encounters that matches (which I assume is in arbitrary order).

logger.debug(
f"Determined that {checkpoint_path} is a {ckpt_type} checkpoint."
)
return ckpt_type
raise ValueError(f"Couldn't determine checkpoint type for {checkpoint_path}")
4 changes: 4 additions & 0 deletions git_theta/scripts/git_theta_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def run_clean(args):
"""
logger = logging.getLogger("git_theta")
logger.debug(f"Running clean filter on {args.file}")
if EnvVarConstants.CHECKPOINT_TYPE == "sniff":
EnvVarConstants.CHECKPOINT_TYPE = checkpoints.sniff_checkpoints()
repo = git_utils.get_git_repo()
checkpoint_handler = checkpoints.get_checkpoint_handler()
if EnvVarConstants.LOW_MEMORY:
Expand Down Expand Up @@ -74,6 +76,8 @@ def run_smudge(args):
"""
logger = logging.getLogger("git_theta")
logger.debug(f"Running smudge filter on {args.file}")
if EnvVarConstants.CHECKPOINT_TYPE == "sniff":
EnvVarConstants.CHECKPOINT_TYPE = checkpoints.sniff_checkpoints()

repo = git_utils.get_git_repo()
curr_metadata = metadata.Metadata.from_file(sys.stdin)
Expand Down
2 changes: 1 addition & 1 deletion git_theta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __get__(self, obj, objtype=None):


class EnvVarConstants:
CHECKPOINT_TYPE = EnvVar(name="GIT_THETA_CHECKPOINT_TYPE", default="pytorch")
CHECKPOINT_TYPE = EnvVar(name="GIT_THETA_CHECKPOINT_TYPE", default="sniff")
UPDATE_TYPE = EnvVar(name="GIT_THETA_UPDATE_TYPE", default="dense")
UPDATE_DATA_PATH = EnvVar(name="GIT_THETA_UPDATE_DATA_PATH", default="")
PARAMETER_ATOL = EnvVar(name="GIT_THETA_PARAMETER_ATOL", default=1e-8)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Version of this package, kept as a plain string literal.
# NOTE(review): presumably this is the __init__.py that a plugin setup.py
# parses for `__version__` via get_version -- confirm which package it is.
__version__ = "0.2.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Infer if a checkpoint is flax based.

We put this in a different file to avoid importing dl frameworks for file sniffing.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a relevant comment for this particular sniffer?

"""


def flax_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` is named like a flax checkpoint.

    Only the path string is inspected; the file itself is never opened.
    """
    # TODO: Check if the actual value is msgpack based on magic numbers?
    flax_suffix = ".flax"
    return checkpoint_path.endswith(flax_suffix)
65 changes: 65 additions & 0 deletions plugins/checkpoints/flax/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Plugin to support the flax checkpoint format."""

import ast
import os

from setuptools import setup


def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment of a string to version_variable in
        file_name.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    # NOTE(review): this helper is duplicated across the plugin setup.py files
    # and is not specific to this plugin; consider moving it to a shared place.
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If the variable
    # name matches the given parameter, grab the value (which will be
    # the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Targets can also be attributes, subscripts or tuples, which
            # have no `.id`; only plain names can be the version variable.
            if not (isinstance(target, ast.Name) and target.id == version_variable):
                continue
            value = node.value
            if isinstance(value, ast.Constant) and isinstance(value.value, str):
                return value.value
    raise ValueError(
        f"Could not find an assignment to {version_variable} " f"within '{file_name}'"
    )


# Package metadata for the flax checkpoint plugin.
setup(
name="git_theta_checkpoints_flax",
description="Plugin to support the flax checkpoint format.",
install_requires=[
# "git_theta",
"flax",
"jax",
],
# The version string is parsed out of the package __init__.py by get_version.
version=get_version("git_theta_checkpoints_flax/__init__.py"),
packages=[
"git_theta_checkpoints_flax",
],
author="Brian Lester",
entry_points={
# Checkpoint handler classes; both names resolve to the same class.
"git_theta.plugins.checkpoints": [
"flax = git_theta_checkpoints_flax.checkpoints:FlaxCheckpoint",
"flax-checkpoint = git_theta_checkpoints_flax.checkpoints:FlaxCheckpoint",
],
# Sniffer callable registered under the group that sniff_checkpoint reads
# to infer a checkpoint's type from its path.
"git_theta.plugins.checkpoint.sniffers": [
"flax = git_theta_checkpoints_flax.sniffer:flax_sniffer",
],
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Version of this package, kept as a plain string literal.
# NOTE(review): presumably this is the __init__.py that a plugin setup.py
# parses for `__version__` via get_version -- confirm which package it is.
__version__ = "0.2.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Infer if a checkpoint is pytorch based.

We put this in a different file to avoid importing dl frameworks for file sniffing.
"""

import re


def pytorch_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` is named like a PyTorch checkpoint.

    Only the path string is inspected; the file itself is never opened.

    Parameters
    ----------
    checkpoint_path : str
        The path of the checkpoint to check.

    Returns
    -------
    is_pytorch : bool
        True if the file is called ``pytorch_model.bin`` (in any directory) or
        the path ends with ``.pt``, ``.pth`` or ``.pyt``.
    """
    # Local import so this fix does not depend on the module's import block.
    import os

    # Many checkpoints on HuggingFace Hub are named this. Compare only the
    # final path component so files inside sub-directories are recognized too.
    if os.path.basename(checkpoint_path) == "pytorch_model.bin":
        return True
    # `.pt` and `.pth` are the conventional PyTorch extensions; `.pyt` is kept
    # because the previous pattern (`\.py?t$`) accepted it.
    return checkpoint_path.endswith((".pt", ".pth", ".pyt"))
64 changes: 64 additions & 0 deletions plugins/checkpoints/pytorch/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Plugin to support the pytorch checkpoint format."""

import ast
import os

from setuptools import setup


def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment of a string to version_variable in
        file_name.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If the variable
    # name matches the given parameter, grab the value (which will be
    # the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Targets can also be attributes, subscripts or tuples, which
            # have no `.id`; only plain names can be the version variable.
            if not (isinstance(target, ast.Name) and target.id == version_variable):
                continue
            value = node.value
            if isinstance(value, ast.Constant) and isinstance(value.value, str):
                return value.value
    raise ValueError(
        f"Could not find an assignment to {version_variable} " f"within '{file_name}'"
    )


# Package metadata for the pytorch checkpoint plugin.
setup(
name="git_theta_checkpoints_pytorch",
description="Plugin to support the pytorch checkpoint format.",
install_requires=[
# "git_theta",
"torch",
],
# The version string is parsed out of the package __init__.py by get_version.
version=get_version("git_theta_checkpoints_pytorch/__init__.py"),
packages=[
"git_theta_checkpoints_pytorch",
],
author="Brian Lester",
entry_points={
# Checkpoint handler classes; both names resolve to the same class.
"git_theta.plugins.checkpoints": [
"pytorch = git_theta_checkpoints_pytorch.checkpoints:PickledDictCheckpoint",
"pickled-dict = git_theta_checkpoints_pytorch.checkpoints:PickledDictCheckpoint",
],
# Sniffer callable registered under the group that sniff_checkpoint reads
# to infer a checkpoint's type from its path.
"git_theta.plugins.checkpoint.sniffers": [
"pytorch = git_theta_checkpoints_pytorch.sniffer:pytorch_sniffer",
],
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Version of this package, kept as a plain string literal.
# NOTE(review): presumably this is the __init__.py that a plugin setup.py
# parses for `__version__` via get_version -- confirm which package it is.
__version__ = "0.2.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Checkpoint using the HF safetensors format.

safetensors has the ability to write model checkpoint from "dl-native" -> "safetensors"
and read "safetensors" -> any "dl-native" framework, not just the one that was
used to write it. Therefore, we read/write with their numpy API.
"""

import safetensors.numpy
from file_or_name import file_or_name

from git_theta.checkpoints import Checkpoint


# TODO(bdlester): Can we leverage the lazying loading ability to make things faster?
class SafeTensorsCheckpoint(Checkpoint):
    """Checkpoint handler that reads and writes the safetensors format.

    See https://github.com/huggingface/safetensors
    """

    name: str = "safetensors"

    @classmethod
    @file_or_name(checkpoint_path="rb")
    def load(cls, checkpoint_path: str):
        """Deserialize a safetensors checkpoint through the numpy API.

        NOTE(review): presumably ``file_or_name`` hands this method an open
        binary file object, since the body calls ``.read()`` -- confirm.
        """
        # The numpy flavour is used on purpose: the downstream dl framework
        # is irrelevant here, the values are only wanted as numpy arrays.
        serialized = checkpoint_path.read()
        return safetensors.numpy.load(serialized)

    @file_or_name(checkpoint_path="wb")
    def save(self, checkpoint_path: str):
        """Serialize this checkpoint with the safetensors numpy API and write it."""
        # git theta works on numpy internally, so the numpy API is used for
        # writing no matter which framework produced the original checkpoint.
        framework_dict = self.to_framework()
        serialized = safetensors.numpy.save(framework_dict)
        checkpoint_path.write(serialized)

    def to_framework(self):
        """Return the checkpoint itself; no conversion is performed."""
        return self

    @classmethod
    def from_framework(cls, model_dict):
        """Wrap ``model_dict`` in an instance of this class."""
        return cls(model_dict)


def safetensors_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` is named like a safetensors checkpoint.

    Only the path string is inspected; the file itself is never opened.
    """
    # NOTE(review): an identical function lives in the dedicated sniffer
    # module, which is the one the plugin's entry point references; it is
    # unclear why this copy is needed here as well.
    safetensors_suffix = ".safetensors"
    return checkpoint_path.endswith(safetensors_suffix)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Infer if a checkpoint is safetensors based.

We put this in a different file to avoid importing dl frameworks for file sniffing.
"""


def safetensors_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` is named like a safetensors checkpoint.

    Only the path string is inspected; the file itself is never opened.
    """
    safetensors_suffix = ".safetensors"
    return checkpoint_path.endswith(safetensors_suffix)
64 changes: 64 additions & 0 deletions plugins/checkpoints/safetensors/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Plugin to support the safetensors checkpoint format."""

import ast
import os

from setuptools import setup


def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Find the version by walking the AST to avoid duplication.

    Parameters
    ----------
    file_name : str
        The file we are parsing to get the version string from.
    version_variable : str
        The variable name that holds the version string.

    Raises
    ------
    ValueError
        If there was no assignment of a string to version_variable in
        file_name.

    Returns
    -------
    version_string : str
        The version string parsed from file_name.
    """
    with open(file_name, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    # Look at all assignment nodes that happen in the ast. If the variable
    # name matches the given parameter, grab the value (which will be
    # the version string we are looking for).
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Targets can also be attributes, subscripts or tuples, which
            # have no `.id`; only plain names can be the version variable.
            if not (isinstance(target, ast.Name) and target.id == version_variable):
                continue
            value = node.value
            if isinstance(value, ast.Constant) and isinstance(value.value, str):
                return value.value
    raise ValueError(
        f"Could not find an assignment to {version_variable} " f"within '{file_name}'"
    )


# Package metadata for the safetensors checkpoint plugin.
setup(
name="git_theta_checkpoints_safetensors",
description="Plugin to support the safetensors checkpoint format.",
install_requires=[
# "git_theta",
"safetensors",
],
# The version string is parsed out of the package __init__.py by get_version.
version=get_version("git_theta_checkpoints_safetensors/__init__.py"),
packages=[
"git_theta_checkpoints_safetensors",
],
author="Brian Lester",
entry_points={
# Checkpoint handler classes; both names resolve to the same class.
"git_theta.plugins.checkpoints": [
"safetensors = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint",
"safetensors-checkpoint = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint",
],
# Sniffer callable registered under the group that sniff_checkpoint reads
# to infer a checkpoint's type from its path.
"git_theta.plugins.checkpoint.sniffers": [
"safetensors = git_theta_checkpoints_safetensors.sniffer:safetensors_sniffer",
],
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Version of this package, kept as a plain string literal.
# NOTE(review): presumably this is the __init__.py that a plugin setup.py
# parses for `__version__` via get_version -- confirm which package it is.
__version__ = "0.2.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Infer if a checkpoint is tensorflow based.

We put this in a different file to avoid importing dl frameworks for file sniffing.
"""


def tensorflow_sniffer(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` is named like a tensorflow checkpoint.

    Only the path string is inspected; the file itself is never opened.
    """
    tensorflow_suffix = ".tf"
    return checkpoint_path.endswith(tensorflow_suffix)


# TODO: Add support for detecting saved models.
def saved_model_sniffer(checkpoint_path: str) -> bool:
    """Placeholder sniffer for saved models; it never claims a checkpoint.

    Saved models are not supported yet, so ``checkpoint_path`` is ignored.
    """
    is_saved_model = False
    return is_saved_model
Loading
Loading