-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement the ability to infer the checkpoint type #254
base: main
Are you sure you want to change the base?
Conversation
Currently this is just based on the file path. In order to make the sniffing extensible, it is implemented as a plugin. In order to do that, and to avoid having to import a deep learning framework to do sniffing, the sniffers are separated out into their own files. This lead to the question of why are checkpoint plugs stored in the main repo? If we move them to the actual plugins area we can continue to support them while not needing to install them if a framework isn't going to be used. Now when you want to git theta with a specific framework you can either run `pip install git-theta[framework]` or run `pip install git-theta-checkpoints-framework` for your framework.
|
||
|
||
def sniff_checkpoint(checkpoint_path) -> str: | ||
"""En""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's En?
) | ||
for ckpt_type, ckpt_sniffer in loaded_plugins.items(): | ||
logger.debug(f"Checking if {checkpoint_path} is a {ckpt_type} checkpoint.") | ||
if ckpt_sniffer(checkpoint_path): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we have any kind of error handling in the case where two checkpoint sniffers denote that a given checkpoint is their type? Currently this will return the checkpoint type corresponding to the first checkpoint sniffer the loop encounters that matches (which I assume is in arbitrary order).
@@ -0,0 +1,9 @@ | |||
"""Infer if a checkpoint is flax based. | |||
|
|||
We put this in a different file to avoid importing dl frameworks for file sniffing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a relevant comment for this particular sniffer?
from setuptools import setup | ||
|
||
|
||
def get_version(file_name: str, version_variable: str = "__version__") -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not seem to be specific to the flax checkpoint sniffer and is duplicated across the setup.py files - should it be pulled out into a more general utility function filea
return cls(model_dict) | ||
|
||
|
||
def safetensors_sniffer(checkpoint_path: str) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this here and in the sniffer file itself?
Currently the included sniffers are just based on the file path, but some formats, for example pytorch's pickle and newer zip format, have magic numbers in their headers we can check for in the future.
In order to make the sniffing extensible, it is implemented as a plugin. In order to do that, and to avoid having to import a deep learning framework to do sniffing, the sniffers are separated out into their own files. This lead to the question of why are checkpoint plugs stored in the main repo? If we move them to the actual plugins area we can continue to support them while not needing to install them if a framework isn't going to be used.
An alternative implementation could include the sniffer plugin as part of the main repo, but the usage would be a bit clunkier as we would need to handle import errors (for example a pytorch sniffer might want to use torch but it isn't installed). The current solution would only sniff for checkpoints from frameworks that you have installed.
From a user perspective, this PR results in the following:
Currently, to use git-theta with a framework, say tensorflow, you need to have tensorflow installed in your python environment and you need to tell git theta that a checkpoint is TF via an environment variable.
pip install git-theta[tensorflow]
is provided as an easy way to ensure both git theta and tensorflow are installed, but it doesn't need to be used.With this version, to use some framework, say tensorflow, you either need to install
git-theta-checkpoints-tensorflow
along withgit-theta
or you can usepip install git-theta[tensorflow]
as a short cut. If later you want to use git-theta with pytorch, you need to installgit-theta-checkpoint-pytorch
.I'd like some feedback on if people think this change to how things are installed would be too much for users.