Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added slice argument for load_hdf5 #1753

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 78 additions & 3 deletions .github/rd-release-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ autolabeler:
- label: 'interoperability'
title:
- '/Support.+/'
- label: 'testing'
files:
- '**/tests/**/*'
- label: 'classification'
files:
- 'heat/classification/**/*'
Expand Down Expand Up @@ -164,6 +161,84 @@ autolabeler:
- label: 'linalg'
files:
- 'heat/core/linalg/**/*'
- label: 'arithmetics'
files:
- 'heat/core/arithmetics.py'
- label: 'base'
files:
- 'heat/core/base.py'
- label: 'communication'
files:
- 'heat/core/communication.py'
- label: 'complex_math'
files:
- 'heat/core/complex_math.py'
- label: 'constants'
files:
- 'heat/core/constants.py'
- label: 'devices'
files:
- 'heat/core/devices.py'
- label: 'dndarray'
files:
- 'heat/core/dndarray.py'
- label: 'exponential'
files:
- 'heat/core/exponential.py'
- label: 'indexing'
files:
- 'heat/core/indexing.py'
- label: 'io'
files:
- 'heat/core/io.py'
- label: 'logical'
files:
- 'heat/core/logical.py'
- label: 'manipulations'
files:
- 'heat/core/manipulations.py'
- label: 'memory'
files:
- 'heat/core/memory.py'
- label: 'printing'
files:
- 'heat/core/printing.py'
- label: 'random'
files:
- 'heat/core/random.py'
- label: 'relational'
files:
- 'heat/core/relational.py'
- label: 'rounding'
files:
- 'heat/core/rounding.py'
- label: 'sanitation'
files:
- 'heat/core/sanitation.py'
- label: 'signal'
files:
- 'heat/core/signal.py'
- label: 'statistics'
files:
- 'heat/core/statistics.py'
- label: 'stride_tricks'
files:
- 'heat/core/stride_tricks.py'
- label: 'tiling'
files:
- 'heat/core/tiling.py'
- label: 'trigonometrics'
files:
- 'heat/core/trigonometrics.py'
- label: 'types'
files:
- 'heat/core/types.py'
- label: 'version'
files:
- 'heat/core/version.py'
- label: 'vmap'
files:
- 'heat/core/vmap.py'

change-template: '- #$NUMBER $TITLE (by @$AUTHOR)'
category-template: '### $TITLE'
Expand Down
70 changes: 53 additions & 17 deletions heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@
"load_npy_from_path",
]


def size_from_slice(size: int, s: slice) -> Tuple[int, int]:
    """
    Determine the length and starting offset that a slice selects from an axis.

    Parameters
    ----------
    size : int
        The size of the array the slice object is applied to.
    s : slice
        The slice object to determine the size of.

    Returns
    -------
    int
        The size of the sliced object.
    int
        The start index of the slice object.
    """
    # Let Python's own range slicing normalize negative/None/out-of-range bounds.
    selected = range(size)[s]
    if len(selected) == 0:
        # Empty selection: report offset 0 rather than the clamped range start.
        return 0, 0
    return len(selected), selected.start


try:
import netCDF4 as nc
except ImportError:
Expand Down Expand Up @@ -489,7 +512,7 @@ def load_hdf5(
path: str,
dataset: str,
dtype: datatype = types.float32,
load_fraction: float = 1.0,
slices: Optional[Tuple[slice]] = None,
split: Optional[int] = None,
device: Optional[str] = None,
comm: Optional[Communication] = None,
Expand All @@ -505,10 +528,8 @@ def load_hdf5(
Name of the dataset to be read.
dtype : datatype, optional
Data type of the resulting array.
load_fraction : float between 0. (excluded) and 1. (included), default is 1.
if 1. (default), the whole dataset is loaded from the file specified in path
else, the dataset is loaded partially, with the fraction of the dataset (along the split axis) specified by load_fraction
If split is None, load_fraction is automatically set to 1., i.e. the whole dataset is loaded.
slices : tuple of slice objects, optional
Load only the specified slices of the dataset.
split : int or None, optional
The axis along which the data is distributed among the processing cores.
device : str, optional
Expand Down Expand Up @@ -545,14 +566,6 @@ def load_hdf5(
elif split is not None and not isinstance(split, int):
raise TypeError(f"split must be None or int, not {type(split)}")

if not isinstance(load_fraction, float):
raise TypeError(f"load_fraction must be float, but is {type(load_fraction)}")
else:
if split is not None and (load_fraction <= 0.0 or load_fraction > 1.0):
raise ValueError(
f"load_fraction must be between 0. (excluded) and 1. (included), but is {load_fraction}."
)

# infer the type and communicator for the loaded array
dtype = types.canonical_heat_type(dtype)
# determine the comm and device the data will be placed on
Expand All @@ -563,13 +576,36 @@ def load_hdf5(
with h5py.File(path, "r") as handle:
data = handle[dataset]
gshape = data.shape
if split is not None:
gshape = list(gshape)
gshape[split] = int(gshape[split] * load_fraction)
gshape = tuple(gshape)
new_gshape = tuple()
offsets = [0] * len(gshape)
if slices is not None:
if len(slices) != len(gshape):
raise ValueError(
f"Number of slices ({len(slices)}) does not match the number of dimensions ({len(gshape)})"
)
for i, s in enumerate(slices):
if s:
if s.step is not None and s.step != 1:
raise ValueError("Slices with step != 1 are not supported")
new_axis_size, offset = size_from_slice(gshape[i], s)
new_gshape += (new_axis_size,)
offsets[i] = offset
else:
new_gshape += (gshape[i],)
offsets[i] = 0

gshape = new_gshape

dims = len(gshape)
split = sanitize_axis(gshape, split)
_, _, indices = comm.chunk(gshape, split)

if slices is not None:
new_indices = tuple()
for offset, index in zip(offsets, indices):
new_indices += (slice(index.start + offset, index.stop + offset),)
indices = new_indices

balanced = True
if split is None:
data = torch.tensor(
Expand Down
73 changes: 65 additions & 8 deletions heat/core/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,31 @@ def tearDown(self):
# synchronize all nodes
ht.MPI_WORLD.Barrier()

def test_size_from_slice(self):
    """Verify ht.io.size_from_slice against Python's own range-slicing semantics."""
    test_cases = [
        (1000, slice(500)),
        (10, slice(0, 10, 2)),
        (100, slice(0, 100, 10)),
        (1000, slice(0, 1000, 100)),
        (0, slice(0)),  # empty input: size and offset must both be 0
    ]
    for size, slice_obj in test_cases:
        with self.subTest(size=size, slice=slice_obj):
            # The ground truth is whatever slicing a materialized range yields.
            expected_sequence = list(range(size))[slice_obj]
            expected_new_size = len(expected_sequence)
            # An empty selection reports offset 0 by convention.
            expected_offset = expected_sequence[0] if expected_sequence else 0

            new_size, offset = ht.io.size_from_slice(size, slice_obj)
            self.assertEqual(expected_new_size, new_size)
            self.assertEqual(expected_offset, offset)

# catch-all loading
def test_load(self):
# HDF5
Expand Down Expand Up @@ -541,10 +566,6 @@ def test_load_hdf5(self):
self.assertEqual(iris.larray.dtype, torch.float32)
self.assertTrue((self.IRIS == iris.larray).all())

# cropped load
iris_cropped = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, split=0, load_fraction=0.5)
self.assertEqual(iris_cropped.shape[0], iris.shape[0] // 2)

# positive split axis
iris = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, split=0)
self.assertIsInstance(iris, ht.DNDarray)
Expand Down Expand Up @@ -582,10 +603,6 @@ def test_load_hdf5_exception(self):
ht.load_hdf5("iris.h5", 1)
with self.assertRaises(TypeError):
ht.load_hdf5("iris.h5", dataset="data", split=1.0)
with self.assertRaises(TypeError):
ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, load_fraction="a")
with self.assertRaises(ValueError):
ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, load_fraction=0.0, split=0)

# file or dataset does not exist
with self.assertRaises(IOError):
Expand Down Expand Up @@ -892,3 +909,43 @@ def test_load_multiple_csv_exception(self):
ht.MPI_WORLD.Barrier()
if ht.MPI_WORLD.rank == 0:
shutil.rmtree(os.path.join(os.getcwd(), "heat/datasets/csv_tests"))

@unittest.skipIf(not ht.io.supports_hdf5(), reason="Requires HDF5")
def test_load_partial_hdf5(self):
    """Check that load_hdf5 with ``slices`` matches slicing a fully loaded array."""
    # Path and dataset name are loop-invariant; compute them once.
    HDF5_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.h5")
    HDF5_DATASET = "data"
    test_axis = [None, 0, 1]
    test_slices = [
        (slice(0, 50, None), slice(None, None, None)),
        (slice(0, 50, None), slice(0, 2, None)),
        (slice(50, 100, None), slice(None, None, None)),
        (slice(None, None, None), slice(2, 4, None)),
        (slice(50), None),
        (None, slice(0, 3, 2)),
    ]

    for axis, slices in [(a, s) for a in test_axis for s in test_slices]:
        with self.subTest(axis=axis, slices=slices):
            # load_hdf5 rejects any slice whose step is not 1 (or None).
            expect_error = any(
                s is not None and s.step not in (None, 1) for s in slices
            )

            if expect_error:
                with self.assertRaises(ValueError):
                    ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
            else:
                # Reference result: load everything, then slice in memory.
                original_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis)
                tmp_slices = tuple(slice(None) if s is None else s for s in slices)
                expected_iris = original_iris[tmp_slices]
                sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
                self.assertTrue(ht.equal(sliced_iris, expected_iris))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"netcdf": ["netCDF4>=1.5.6"],
"dev": ["pre-commit>=1.18.3"],
"examples": ["scikit-learn>=0.24.0", "matplotlib>=3.1.0"],
"cb": ["perun>=0.2.0"],
"cb": ["perun>=0.8"],
"pandas": ["pandas>=1.4"],
},
)
Loading