Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the ability to infer the checkpoint type #254

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

blester125
Copy link
Collaborator

Currently the included sniffers are just based on the file path, but some formats, for example pytorch's pickle and newer zip format, have magic numbers in their headers we can check for in the future.

In order to make the sniffing extensible, it is implemented as a plugin. In order to do that, and to avoid having to import a deep learning framework to do sniffing, the sniffers are separated out into their own files. This lead to the question of why are checkpoint plugs stored in the main repo? If we move them to the actual plugins area we can continue to support them while not needing to install them if a framework isn't going to be used.

An alternative implementation could include the sniffer plugin as part of the main repo, but the usage would be a bit clunkier as we would need to handle import errors (for example a pytorch sniffer might want to use torch but it isn't installed). The current solution would only sniff for checkpoints from frameworks that you have installed.

From a user perspective, this PR results in the following:

Currently, to use git-theta with a framework, say tensorflow, you need to have tensorflow installed in your python environment and you need to tell git theta that a checkpoint is TF via an environment variable. pip install git-theta[tensorflow] is provided as an easy way to ensure both git theta and tensorflow are installed, but it doesn't need to be used.

With this version, to use some framework, say tensorflow, you either need to install git-theta-checkpoints-tensorflow along with git-theta or you can use pip install git-theta[tensorflow] as a short cut. If later you want to use git-theta with pytorch, you need to install git-theta-checkpoint-pytorch.

I'd like some feedback on if people think this change to how things are installed would be too much for users.

Currently this is just based on the file path. In order to make the
sniffing extensible, it is implemented as a plugin. In order to do that,
and to avoid having to import a deep learning framework to do sniffing,
the sniffers are separated out into their own files. This lead to the
question of why are checkpoint plugs stored in the main repo? If we move
them to the actual plugins area we can continue to support them while
not needing to install them if a framework isn't going to be used.

Now when you want to git theta with a specific framework you can either
run `pip install git-theta[framework]` or run `pip install
git-theta-checkpoints-framework` for your framework.
@blester125 blester125 requested review from craffel and nkandpa2 May 26, 2024 20:00


def sniff_checkpoint(checkpoint_path) -> str:
"""En"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's En?

)
for ckpt_type, ckpt_sniffer in loaded_plugins.items():
logger.debug(f"Checking if {checkpoint_path} is a {ckpt_type} checkpoint.")
if ckpt_sniffer(checkpoint_path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have any kind of error handling in the case where two checkpoint sniffers denote that a given checkpoint is their type? Currently this will return the checkpoint type corresponding to the first checkpoint sniffer the loop encounters that matches (which I assume is in arbitrary order).

@@ -0,0 +1,9 @@
"""Infer if a checkpoint is flax based.

We put this in a different file to avoid importing dl frameworks for file sniffing.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a relevant comment for this particular sniffer?

from setuptools import setup


def get_version(file_name: str, version_variable: str = "__version__") -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not seem to be specific to the flax checkpoint sniffer and is duplicated across the setup.py files - should it be pulled out into a more general utility function filea

return cls(model_dict)


def safetensors_sniffer(checkpoint_path: str) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this here and in the sniffer file itself?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants