|
| 1 | +"""Plugin to support the safetensor format.""" |
| 2 | + |
| 3 | +import ast |
| 4 | +import os |
| 5 | + |
| 6 | +from setuptools import setup |
| 7 | + |
| 8 | + |
| 9 | +def get_version(file_name: str, version_variable: str = "__version__") -> str: |
| 10 | + """Find the version by walking the AST to avoid duplication. |
| 11 | +
|
| 12 | + Parameters |
| 13 | + ---------- |
| 14 | + file_name : str |
| 15 | + The file we are parsing to get the version string from. |
| 16 | + version_variable : str |
| 17 | + The variable name that holds the version string. |
| 18 | +
|
| 19 | + Raises |
| 20 | + ------ |
| 21 | + ValueError |
| 22 | + If there was no assignment to version_variable in file_name. |
| 23 | +
|
| 24 | + Returns |
| 25 | + ------- |
| 26 | + version_string : str |
| 27 | + The version string parsed from file_name_name. |
| 28 | + """ |
| 29 | + with open(file_name) as f: |
| 30 | + tree = ast.parse(f.read()) |
| 31 | + # Look at all assignment nodes that happen in the ast. If the variable |
| 32 | + # name matches the given parameter, grab the value (which will be |
| 33 | + # the version string we are looking for). |
| 34 | + for node in ast.walk(tree): |
| 35 | + if isinstance(node, ast.Assign): |
| 36 | + if node.targets[0].id == version_variable: |
| 37 | + return node.value.s |
| 38 | + raise ValueError( |
| 39 | + f"Could not find an assignment to {version_variable} " f"within '{file_name}'" |
| 40 | + ) |
| 41 | + |
| 42 | + |
| 43 | +setup( |
| 44 | + name="git_theta_checkpoints_safetensors", |
| 45 | + description="Plugin to support the safetensors checkpoint format.", |
| 46 | + install_requires=[ |
| 47 | + # "git_theta", |
| 48 | + "safetensors", |
| 49 | + ], |
| 50 | + version=get_version("git_theta_checkpoints_safetensors/__init__.py"), |
| 51 | + packages=[ |
| 52 | + "git_theta_checkpoints_safetensors", |
| 53 | + ], |
| 54 | + author="Brian Lester", |
| 55 | + entry_points={ |
| 56 | + "git_theta.plugins.checkpoints": [ |
| 57 | + "safetensors = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint", |
| 58 | + "safetensors-checkpoint = git_theta_checkpoints_safetensors.checkpoints:SafeTensorsCheckpoint", |
| 59 | + ], |
| 60 | + "git_theta.plugins.checkpoint.sniffers": [ |
| 61 | + "safetensors = git_theta_checkpoints_safetensors.sniffer:safetensors_sniffer", |
| 62 | + ], |
| 63 | + }, |
| 64 | +) |
0 commit comments