-
Notifications
You must be signed in to change notification settings - Fork 98
feat: add to_/from_safetensors #3685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
pfackeldey
merged 28 commits into
scikit-hep:main
from
pfackeldey:to_from_safetensors.py
Oct 22, 2025
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
d5180f6
feat: add to_/from_safetensors
pfackeldey 257006f
Merge branch 'main' into to_from_safetensors.py
pfackeldey c4345c5
style: pre-commit fixes
pre-commit-ci[bot] 1c9e370
satisfy pre-commit
pfackeldey ce4a86b
add test
pfackeldey 98fb3cc
satisfy pylint too
pfackeldey aefccf9
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 0edea75
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 19871f4
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 15338b2
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey 5737374
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 818ddda
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey c4a8aa5
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 1c68716
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey fecc00e
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 1de11b9
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 65829b4
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 3f23cd5
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey a3339b6
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey a6fb568
Update src/awkward/operations/ak_from_safetensors.py
pfackeldey 895888b
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey b63b72b
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey c7724d3
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey 72191b4
Update src/awkward/operations/ak_to_safetensors.py
pfackeldey 76246fa
address remaining comments
pfackeldey 42d8920
Merge branch 'main' into to_from_safetensors.py
pfackeldey 960c99c
make sure arrays are packed before serializing to safetensors
pfackeldey 03f2e73
use fsspec to allow remote writing and reading
pfackeldey File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| # BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import fsspec | ||
|
|
||
| import awkward as ak | ||
| from awkward._dispatch import high_level_function | ||
|
|
||
| __all__ = ("from_safetensors",) | ||
|
|
||
|
|
||
| @high_level_function() | ||
| def from_safetensors( | ||
| source, | ||
| *, | ||
| storage_options=None, | ||
| virtual=False, | ||
| # ak.from_buffers kwargs | ||
| buffer_key="{form_key}-{attribute}", | ||
| backend="cpu", | ||
| byteorder="<", | ||
| allow_noncanonical_form=False, | ||
| highlevel=True, | ||
| behavior=None, | ||
| attrs=None, | ||
| ): | ||
| """ | ||
| Args: | ||
| source (path-like): Name of the input file, file path, or | ||
| remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs) | ||
| for remote reading. | ||
| storage_options (None or dict): Any additional options to pass to | ||
| [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs) | ||
| to open a remote file for reading. | ||
| virtual (bool, optional): If True, create a virtual (lazy) Awkward Array | ||
| that references buffers without materializing them. Defaults to False. | ||
| buffer_key (str, optional): Template for buffer names, with placeholders | ||
| `{form_key}` and `{attribute}`. Defaults to "{form_key}-{attribute}". | ||
| backend (str, optional): Backend identifier (e.g., "cpu"). Defaults to "cpu". | ||
| byteorder (str, optional): Byte order, "<" (little-endian, default) or ">". | ||
| allow_noncanonical_form (bool, optional): If True, normalize | ||
| safetensors forms that do not directly match Awkward. Defaults to False. | ||
| highlevel (bool, optional): If True, return a high-level ak.Array. If False, | ||
| return the low-level layout. Defaults to True. | ||
| behavior (Mapping | None, optional): Optional Awkward behavior mapping. | ||
| attrs (Mapping | None, optional): Optional metadata to attach to the array. | ||
|
|
||
| Returns: | ||
| ak.Array or ak.layout.Content: An Awkward Array (or layout) reconstructed | ||
| from the safetensors buffers. | ||
|
|
||
| Load a safetensors file as an Awkward Array. | ||
|
|
||
| Ref: https://huggingface.co/docs/safetensors/. | ||
|
|
||
| This function reads data serialized in the safetensors format and reconstructs | ||
| an Awkward Array (or low-level layout) from it. Buffers in the safetensors file | ||
| are mapped to Awkward buffers according to the `buffer_key` template, and | ||
| optional behavior or attributes can be attached to the returned array. | ||
|
|
||
| The safetensors file **must contain** `form` and `length` entries in its | ||
| metadata, which define the structure and length of the reconstructed array. | ||
|
|
||
| Example: | ||
|
|
||
| >>> import awkward as ak | ||
| >>> arr = ak.from_safetensors("out.safetensors") | ||
| >>> arr # doctest: +SKIP | ||
| <Array [[1, 2, 3], [], [4]] type='3 * var * int64'> | ||
|
|
||
| Create a virtual (lazy) array that references buffers without materializing them: | ||
|
|
||
| >>> virtual_arr = ak.from_safetensors("out.safetensors", virtual=True) | ||
| >>> virtual_arr # doctest: +SKIP | ||
| <Array [??, ??, ??] type='3 * var * int64'> | ||
|
|
||
|
|
||
| See also #ak.to_safetensors. | ||
| """ | ||
| # Implementation | ||
| return _impl( | ||
| source, | ||
| storage_options, | ||
| virtual, | ||
| buffer_key, | ||
| backend, | ||
| byteorder, | ||
| allow_noncanonical_form, | ||
| highlevel, | ||
| behavior, | ||
| attrs, | ||
| ) | ||
|
|
||
|
|
||
| def _impl( | ||
| source, | ||
| storage_options, | ||
| virtual, | ||
| buffer_key, | ||
| backend, | ||
| byteorder, | ||
| allow_noncanonical_form, | ||
| highlevel, | ||
| behavior, | ||
| attrs, | ||
| ): | ||
| try: | ||
| from safetensors import _safe_open_handle | ||
| except ImportError as err: | ||
| raise ImportError( | ||
| """to use ak.from_tensorflow, you must install the 'safetensors' package with: | ||
|
|
||
| pip install safetensors | ||
| or | ||
| conda install -c huggingface safetensors""" | ||
| ) from err | ||
|
|
||
| fs, source = fsspec.core.url_to_fs(source, **(storage_options or {})) | ||
|
|
||
| buffers = {} | ||
|
|
||
| def maybe_virtualize(x): | ||
| return (lambda: x) if virtual else x | ||
|
|
||
| with fs.open(source, "rb") as f: | ||
| with _safe_open_handle(f, framework="np") as g: | ||
| metadata = g.metadata() | ||
| for k in g.offset_keys(): | ||
| buffers[k] = maybe_virtualize(g.get_tensor(k)) | ||
|
|
||
| if "form" not in metadata or "length" not in metadata: | ||
| raise RuntimeError( | ||
| "Missing required metadata in safetensors file: 'form' and 'length' are required." | ||
| ) | ||
| form = ak.forms.from_json(metadata["form"]) | ||
| length = int(metadata["length"]) | ||
|
|
||
| # reconstruct array | ||
| return ak.ak_from_buffers._impl( | ||
| form, | ||
| length, | ||
| buffers, | ||
| buffer_key=buffer_key, | ||
| backend=backend, | ||
| byteorder=byteorder, | ||
| simplify=allow_noncanonical_form, | ||
| highlevel=highlevel, | ||
| behavior=behavior, | ||
| attrs=attrs, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| # BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import fsspec | ||
|
|
||
| import awkward as ak | ||
| from awkward._dispatch import high_level_function | ||
| from awkward._layout import HighLevelContext | ||
|
|
||
| __all__ = ("to_safetensors",) | ||
|
|
||
|
|
||
| @high_level_function() | ||
| def to_safetensors( | ||
| array, | ||
| destination, | ||
| *, | ||
| storage_options=None, | ||
| # ak.to_buffers kwargs | ||
| container=None, | ||
| buffer_key="{form_key}-{attribute}", | ||
| form_key="node{id}", | ||
| id_start=0, | ||
| backend=None, | ||
| byteorder=ak._util.native_byteorder, | ||
| ): | ||
| """ | ||
| Args: | ||
| array: An Awkward Array or array-like object to serialize. | ||
| destination (path-like): Name of the output file, file path, or | ||
| remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs) | ||
| for remote writing. | ||
| storage_options (None or dict): Any additional options to pass to | ||
| [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs) | ||
| to open a remote file for writing. | ||
| container (dict, optional): Optional mapping to receive the generated buffer | ||
| bytes. If None (default), a temporary container is used and discarded | ||
| after writing. | ||
| buffer_key (str, optional): Format string for naming buffers. May include | ||
| `{form_key}` and `{attribute}` placeholders. Defaults to | ||
| `"{form_key}-{attribute}"`. | ||
| form_key (str, optional): Format string for node forms when generating buffer | ||
| keys. Typically includes `"{id}"`. Defaults to `"node{id}"`. | ||
| id_start (int, optional): Starting index for node numbering. Defaults to `0`. | ||
| backend (str | object, optional): Backend used to convert array data into | ||
| buffers. If None, the default backend is used. | ||
| byteorder (str, optional): Byte order for numeric buffers. Defaults to the | ||
| system's native byte order. | ||
|
|
||
| Returns: | ||
| None | ||
| This function writes the safetensors file to `destination`. If | ||
| `container` is provided, it will be populated with the raw buffer bytes. | ||
|
|
||
| Serialize an Awkward Array to the safetensors format and write it to `destination`. | ||
|
|
||
| Ref: https://huggingface.co/docs/safetensors/. | ||
|
|
||
| This function converts the provided Awkward Array (or array-like object) into raw | ||
| buffers via `ak.to_buffers` and stores them in the safetensors format. Buffer names | ||
| are generated from `buffer_key` and `form_key` templates, allowing downstream | ||
| compatibility or layout reuse. | ||
| The resulting safetensors file includes metadata containing the Awkward `form` and | ||
| array `length`, which are required for `ak.from_safetensors` to reconstruct the array. | ||
|
|
||
| Example: | ||
|
|
||
| >>> import awkward as ak | ||
| >>> arr = ak.Array([[1, 2, 3], [], [4]]) | ||
| >>> ak.to_safetensors(arr, "out.safetensors") | ||
|
|
||
|
|
||
| See also #ak.from_safetensors. | ||
| """ | ||
| # Implementation | ||
| return _impl( | ||
| array, | ||
| destination, | ||
| storage_options, | ||
| container, | ||
| buffer_key, | ||
| form_key, | ||
| id_start, | ||
| backend, | ||
| byteorder, | ||
| ) | ||
|
|
||
|
|
||
| def _impl( | ||
| array, | ||
| destination, | ||
| storage_options, | ||
| container, | ||
| buffer_key, | ||
| form_key, | ||
| id_start, | ||
| backend, | ||
| byteorder, | ||
| ): | ||
| try: | ||
| from safetensors.numpy import save | ||
| except ImportError as err: | ||
| raise ImportError( | ||
| """to use ak.to_safetensors, you must install the 'safetensors' package with: | ||
|
|
||
| pip install safetensors | ||
| or | ||
| conda install -c huggingface safetensors""" | ||
| ) from err | ||
|
|
||
| fs, destination = fsspec.core.url_to_fs(destination, **(storage_options or {})) | ||
|
|
||
| with HighLevelContext(behavior=None, attrs=None) as ctx: | ||
| layout = ctx.unwrap(array, allow_record=True, primitive_policy="error") | ||
|
|
||
| layout = ak.ak_to_packed._impl( | ||
| layout, | ||
| highlevel=False, # doesn't matter, but we can avoid extra wrapping/unwrapping | ||
| behavior=ctx.behavior, | ||
| attrs=ctx.attrs, | ||
| ) | ||
|
|
||
| form, length, buffers = ak.ak_to_buffers._impl( | ||
| layout, | ||
| container=container, | ||
| buffer_key=buffer_key, | ||
| form_key=form_key, | ||
| id_start=id_start, | ||
| backend=backend, | ||
| byteorder=byteorder, | ||
| ) | ||
|
|
||
| metadata = { | ||
| "form": form.to_json(), | ||
| "length": str(length), | ||
| } | ||
|
|
||
| byts = save(buffers, metadata) | ||
| # save | ||
| try: | ||
| with fs.open(destination, "wb") as f: | ||
| f.write(byts) | ||
| except Exception as err: | ||
| raise RuntimeError( | ||
| f"Failed to write safetensors file to '{destination}': {err}" | ||
| ) from err |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
| # | ||
| from __future__ import annotations | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
|
|
||
| safetensors = pytest.importorskip("safetensors") | ||
|
|
||
|
|
||
| def test_roundtrip(): | ||
| import awkward as ak | ||
|
|
||
| array = ak.Array([[1, 2, 3], [], [4, 5], [6], [7, 8, 9, 10]]) | ||
|
|
||
| path = "./test.safetensors" | ||
| ak.to_safetensors(array, path) | ||
|
|
||
| loaded = ak.from_safetensors(path) | ||
| virtual_loaded = ak.from_safetensors(path, virtual=True) | ||
|
|
||
| os.remove(path) | ||
|
|
||
| assert array.layout.is_equal_to(loaded.layout, all_parameters=True) | ||
| assert array.layout.is_equal_to( | ||
| virtual_loaded.layout.materialize(), all_parameters=True | ||
| ) | ||
|
|
||
|
|
||
| def test_virtual_array_to_safetensors(): | ||
| import awkward as ak | ||
|
|
||
| array = ak.Array([[1, 2, 3], [], [4, 5], [6], [7, 8, 9, 10]]) | ||
|
|
||
| path = "./test_virtual{}.safetensors".format | ||
|
|
||
| ak.to_safetensors(array, path(0)) | ||
| virtual_loaded = ak.from_safetensors(path(0), virtual=True) | ||
|
|
||
| ak.to_safetensors(virtual_loaded, path(1)) | ||
| loaded = ak.from_safetensors(path(1), virtual=False) | ||
|
|
||
| os.remove(path(0)) | ||
| os.remove(path(1)) | ||
|
|
||
| assert virtual_loaded.layout.materialize().is_equal_to( | ||
| loaded.layout, all_parameters=True | ||
| ) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.