From f297ea220bde9eb997222729bf59d85a77810d8c Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Mon, 17 May 2021 16:18:33 +0200
Subject: [PATCH 01/17] added partition interface to DNDarray

---
 heat/core/dndarray.py | 100 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 100 insertions(+)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index e5485f1472..5d27d3a3ff 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -58,6 +58,8 @@ class DNDarray:
         Describes whether the data are evenly distributed across processes.
         If this information is not available (``self.balanced is None``), it
         can be gathered via the :func:`is_distributed()` method (requires communication).
+    partition_interface: bool, Optional
+        If true, a partition interface will be created
     """

     def __init__(
@@ -80,6 +82,7 @@ def __init__(
         self.__ishalo = False
         self.__halo_next = None
         self.__halo_prev = None
+        self.__partition_interface__ = None

         # check for inconsistencies between torch and heat devices
         assert str(array.device) == device.torch_device
@@ -195,6 +198,24 @@ def ndim(self) -> int:
         """
         return len(self.__gshape)

+    @property
+    def partition_interface(self) -> dict:
+        """
+        This will return a dictionary containing information useful for working with the partitioned
+        data. These items include the shape of the data on each process, the starting index of the data
+        that a process has, the datatype of the data, the local devices, as well as the global
+        partitioning scheme.
+
+        An example of the output and shape is shown in :func:`ht.core.DNDarray.create_partition_interface`.
+
+        Returns
+        -------
+        dictionary with the partition interface
+        """
+        if self.__partition_interface__ is None:
+            self.__partition_interface__ = self.create_partition_interface()
+        return self.__partition_interface__
+
     @property
     def size(self) -> int:
         """
@@ -573,6 +594,85 @@ def create_lshape_map(self) -> torch.Tensor:
         self.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)
         return lshape_map

+    def create_partition_interface(self):
+        """
+        Create a partition interface in line with the DPPY proposal. This is subject to change.
+        The intention of this is to facilitate the usage of a general format for the referencing of
+        distributed datasets.
+
+        An example of the output and shape is shown below.
+
+        __partitioned_interface__ = {
+            'shape': (27, 3, 2),
+            'partition_tiling': (4, 1, 1),
+            'partitions': {
+                (0, 0, 0): {
+                    'start': (0, 0, 0),
+                    'shape': (7, 3, 2),
+                    'data': tensor([...], dtype=torch.int32),
+                    'location': 0,
+                    'dtype': torch.int32,
+                    'device': 'cpu'
+                },
+                (1, 0, 0): {
+                    'start': (7, 0, 0),
+                    'shape': (7, 3, 2),
+                    'data': None,
+                    'location': 1,
+                    'dtype': torch.int32,
+                    'device': 'cpu'
+                },
+                (2, 0, 0): {
+                    'start': (14, 0, 0),
+                    'shape': (7, 3, 2),
+                    'data': None,
+                    'location': 2,
+                    'dtype': torch.int32,
+                    'device': 'cpu'
+                },
+                (3, 0, 0): {
+                    'start': (21, 0, 0),
+                    'shape': (6, 3, 2),
+                    'data': None,
+                    'location': 3,
+                    'dtype': torch.int32,
+                    'device': 'cpu'
+                }
+            },
+        }
+
+        Returns
+        -------
+        dictionary containing the partition interface as shown above.
+ """ + lshape_map = self.create_lshape_map() + start_idx_map = torch.zeros_like(lshape_map) + z = torch.tensor([0], device=self.device.torch_device, dtype=self.dtype.torch_type()) + starts = torch.cat((z, torch.cumsum(lshape_map[:, self.split], dim=0)[:-1]), dim=0) + start_idx_map[:, self.split] = starts + part_tiling = [1] * self.ndim + part_tiling[self.split] = self.comm.size + + partitions = {} + base_key = [0] * self.ndim + for r in range(self.comm.size): + base_key[self.split] = r + partitions[tuple(base_key)] = { + "start": tuple(start_idx_map[r].tolist()), + "shape": tuple(lshape_map[r].tolist()), + "data": None if r != self.comm.rank else self.larray, + "location": r, + "dtype": self.dtype.torch_type(), + "device": self.device.torch_device, + } + + partition_dict = { + "shape": self.gshape, + "partition_tiling": tuple(part_tiling), + "partitions": partitions, + } + return partition_dict + def __float__(self) -> DNDarray: """ Float scalar casting. From 919159952bf5c788c937ae357f649573984ccf1e Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 19 May 2021 16:11:03 +0200 Subject: [PATCH 02/17] added 'locals' key to partition interface --- heat/core/dndarray.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 5d27d3a3ff..76b0be59dc 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -639,6 +639,7 @@ def create_partition_interface(self): 'device': 'cpu' } }, + 'locals': [(rank, 0, 0)], } Returns @@ -666,10 +667,13 @@ def create_partition_interface(self): "device": self.device.torch_device, } + locals = [0] * self.ndim + locals[self.split] = self.comm.rank partition_dict = { "shape": self.gshape, "partition_tiling": tuple(part_tiling), "partitions": partitions, + "locals": locals, } return partition_dict From 89fda67a0c84fa3d5b74251eb1a141a6d5d33764 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 19 May 2021 16:11:41 +0200 Subject: [PATCH 03/17] renamed locals to lcls to avoid global name --- heat/core/dndarray.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 76b0be59dc..01898684db 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -667,13 +667,13 @@ def create_partition_interface(self): "device": self.device.torch_device, } - locals = [0] * self.ndim - locals[self.split] = self.comm.rank + lcls = [0] * self.ndim + lcls[self.split] = self.comm.rank partition_dict = { "shape": self.gshape, "partition_tiling": tuple(part_tiling), "partitions": partitions, - "locals": locals, + "locals": lcls, } return partition_dict From 8806b7c30e59369131b0eee8b473f8961598ca52 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 19 May 2021 16:39:16 +0200 Subject: [PATCH 04/17] corrected format of locals --- heat/core/dndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 01898684db..92334fce86 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -673,7 +673,7 @@ def create_partition_interface(self): "shape": self.gshape, "partition_tiling": tuple(part_tiling), "partitions": partitions, - "locals": lcls, + "locals": [tuple(lcls)], } return partition_dict From 51368b70bbc83e49b53ef350852ac63868022eca Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 19 May 2021 16:54:55 +0200 Subject: [PATCH 05/17] renamed dunder class attr of DNDarray to __partitioned__ --- heat/core/dndarray.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/heat/core/dndarray.py b/heat/core/dndarray.py index 92334fce86..0c425fd651 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -58,8 +58,8 @@ class DNDarray: Describes whether the data are evenly distributed across processes. If this information is not available (``self.balanced is None``), it can be gathered via the :func:`is_distributed()` method (requires communication). - partition_interface: bool, Optional - If true, a partition interface will be created + # generate_partitioned: bool, Optional + # If true, a partition interface will be created """ def __init__( @@ -82,7 +82,7 @@ def __init__( self.__ishalo = False self.__halo_next = None self.__halo_prev = None - self.__partition_interface__ = None + self.__partitioned__ = None # check for inconsistencies between torch and heat devices assert str(array.device) == device.torch_device From 81e47c8a1f7df1a721efbc6952eff98c03df226f Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 19 May 2021 17:09:07 +0200 Subject: [PATCH 06/17] corrected split=0 case, corrected DNDarray property to be 'partitioned' --- heat/core/dndarray.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 0c425fd651..b1d21b6000 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -199,7 +199,7 @@ def ndim(self) -> int: return len(self.__gshape) @property - def partition_interface(self) -> dict: + def partitioned(self) -> dict: """ This will return a dictionary containing information useful for working with the partitioned data. These items include the shape of the data on each process, the starting index of the data @@ -212,9 +212,9 @@ def partition_interface(self) -> dict: ------- dictionary with the partition interface """ - if self.__partition_interface__ is None: - self.__partition_interface__ = self.create_partition_interface() - return self.__partition_interface__ + if self.__partitioned__ is None: + self.__partitioned__ = self.create_partition_interface() + return self.__partitioned__ @property def size(self) -> int: @@ -646,29 +646,41 @@ def create_partition_interface(self): ------- dictionary containing the partition interface as shown above. 
""" + # sp = lshape_map = self.create_lshape_map() start_idx_map = torch.zeros_like(lshape_map) + + part_tiling = [1] * self.ndim + lcls = [0] * self.ndim + z = torch.tensor([0], device=self.device.torch_device, dtype=self.dtype.torch_type()) - starts = torch.cat((z, torch.cumsum(lshape_map[:, self.split], dim=0)[:-1]), dim=0) + if self.split is not None: + starts = torch.cat((z, torch.cumsum(lshape_map[:, self.split], dim=0)[:-1]), dim=0) + lcls[self.split] = self.comm.rank + part_tiling[self.split] = self.comm.size + else: + starts = torch.zeros(self.ndim, dtype=torch.int, device=self.device.torch_device) + start_idx_map[:, self.split] = starts - part_tiling = [1] * self.ndim - part_tiling[self.split] = self.comm.size partitions = {} base_key = [0] * self.ndim for r in range(self.comm.size): - base_key[self.split] = r + if self.split is not None: + base_key[self.split] = r + dat = None if r != self.comm.rank else self.larray + else: + dat = self.larray + partitions[tuple(base_key)] = { "start": tuple(start_idx_map[r].tolist()), "shape": tuple(lshape_map[r].tolist()), - "data": None if r != self.comm.rank else self.larray, + "data": dat, "location": r, "dtype": self.dtype.torch_type(), "device": self.device.torch_device, } - lcls = [0] * self.ndim - lcls[self.split] = self.comm.rank partition_dict = { "shape": self.gshape, "partition_tiling": tuple(part_tiling), From 857f585e97fb181602c9c157db3968f7b1dd22f0 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 19 May 2021 17:15:40 +0200 Subject: [PATCH 07/17] DNDarray.__partitioned__ -> __partitions_dict__, DNDarray.partitioned -> __partitioned__ --- heat/core/dndarray.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index b1d21b6000..4cf0025720 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -82,7 +82,7 @@ def __init__( self.__ishalo = False self.__halo_next = None self.__halo_prev = None - self.__partitioned__ = None + self.__partitions_dict__ = None # check for inconsistencies between torch and heat devices assert str(array.device) == device.torch_device @@ -199,7 +199,7 @@ def ndim(self) -> int: return len(self.__gshape) @property - def partitioned(self) -> dict: + def __partitioned__(self) -> dict: """ This will return a dictionary containing information useful for working with the partitioned data. 
These items include the shape of the data on each process, the starting index of the data @@ -212,9 +212,9 @@ def partitioned(self) -> dict: ------- dictionary with the partition interface """ - if self.__partitioned__ is None: - self.__partitioned__ = self.create_partition_interface() - return self.__partitioned__ + if self.__partitions_dict__ is None: + self.__partitions_dict__ = self.create_partition_interface() + return self.__partitions_dict__ @property def size(self) -> int: From d69fdd6a3a0761e948ca767545e2c8cd08ce912d Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 22 Jun 2021 11:55:51 +0200 Subject: [PATCH 08/17] added tests for partitioned attribute --- heat/core/tests/test_dndarray.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 17950daff4..44ab6609b6 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -738,6 +738,15 @@ def test_or(self): ht.equal(int16_tensor | int16_vector, ht.bitwise_or(int16_tensor, int16_vector)) ) + def test_partitioned(self): + a = ht.zeros((120, 120), split=0) + parted = a.__partitioned__ + self.assertEqual(parted["shape"], (120, 120)) + self.assertEqual(parted["partition_tiling"], (a.comm.size, 1)) + self.assertEqual(parted["shape"], (120, 120)) + if a.comm.rank == 0: + self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0)) + def test_redistribute(self): # need to test with 1, 2, 3, and 4 dims st = ht.zeros((50,), split=0) From 0afc772619d6d9c9f1c82fcd17373dc9de33b77a Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 22 Jun 2021 12:17:39 +0200 Subject: [PATCH 09/17] minor changes to test cases to check that things after the resplit are taken care of --- heat/core/dndarray.py | 5 ++++- heat/core/tests/test_dndarray.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 483ca79b71..457a816c0c 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -620,7 +620,7 @@ def create_partition_interface(self): An example of the output and shape is shown below. 
-        __partitioned_interface__ = {
+        __partitioned__ = {
             'shape': (27, 3, 2),
             'partition_tiling': (4, 1, 1),
             'partitions': {
                 (0, 0, 0): {
                     'start': (0, 0, 0),
@@ -1353,6 +1353,9 @@ def resplit_(self, axis: int = None):
         # early out for unchanged content
         if axis == self.split:
             return self
+
+        self.__partitions_dict__ = None
+
         if axis is None:
             gathered = torch.empty(
                 self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index 44ab6609b6..c90a4777af 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -744,8 +744,10 @@ def test_partitioned(self):
         self.assertEqual(parted["shape"], (120, 120))
         self.assertEqual(parted["partition_tiling"], (a.comm.size, 1))
         self.assertEqual(parted["shape"], (120, 120))
-        if a.comm.rank == 0:
-            self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))
+        self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))
+
+        a.resplit_(1)
+        self.assertIsNone(a.__partitions_dict__)

     def test_redistribute(self):
         # need to test with 1, 2, 3, and 4 dims
         st = ht.zeros((50,), split=0)

From 750cc2bb94e579ee81a6bcab94646bfad7d28a86 Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Tue, 22 Jun 2021 12:21:54 +0200
Subject: [PATCH 10/17] split=None tests

---
 heat/core/tests/test_dndarray.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index c90a4777af..491a04a728 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -743,11 +743,14 @@ def test_partitioned(self):
         parted = a.__partitioned__
         self.assertEqual(parted["shape"], (120, 120))
         self.assertEqual(parted["partition_tiling"], (a.comm.size, 1))
-        self.assertEqual(parted["shape"], (120, 120))
         self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))

-        a.resplit_(1)
+        a.resplit_(None)
         self.assertIsNone(a.__partitions_dict__)
+        parted = a.__partitioned__
+        self.assertEqual(parted["shape"], (120, 120))
+        self.assertEqual(parted["partition_tiling"], (1, 1))
+        self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))

     def test_redistribute(self):
         # need to test with 1, 2, 3, and 4 dims

From 992c3853fe567b11721b0a8e591af28f428eafe5 Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Tue, 13 Jul 2021 10:51:18 +0200
Subject: [PATCH 11/17] changelog update

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 107c153cb8..0498c6f4d8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -65,6 +65,7 @@ Example on 2 processes:
 ### Misc.
 - [#761](https://github.com/helmholtz-analytics/heat/pull/761) New feature: `result_type`
+- [#788](https://github.com/helmholtz-analytics/heat/pull/788) Added the partition interface to `DNDarray` for use with DPPY
 - [#794](https://github.com/helmholtz-analytics/heat/pull/794) New feature: `meshgrid`
 - [#821](https://github.com/helmholtz-analytics/heat/pull/821) Enhancement: it is no longer necessary to load-balance an imbalanced `DNDarray` before gathering it onto all processes. In short: `ht.resplit(array, None)` now works on imbalanced arrays as well.
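Before the 'get' accessor and the factory functions are added in the patches below, the interface built up by patches 1-10 can already be exercised. A minimal usage sketch (illustration only, not part of any patch; it assumes a 2-process MPI run, e.g. a hypothetical mpirun -np 2 python script.py):

    import heat as ht

    a = ht.arange(8, split=0)      # 2 processes -> 4 elements per process
    p = a.__partitioned__          # built lazily, then cached in __partitions_dict__

    p["shape"]                     # (8,)  global shape
    p["partition_tiling"]          # (2,)  one partition per process along the split axis
    key = p["locals"][0]           # this rank's partition key: (0,) on rank 0, (1,) on rank 1
    p["partitions"][key]["start"]  # global start index of the local tile, e.g. (4,) on rank 1
    p["partitions"][key]["data"]   # the local torch.Tensor; remote entries hold None

Resplitting invalidates the cached table, as the tests above verify: after a.resplit_(None), the next access to a.__partitioned__ rebuilds it with partition_tiling == (1,).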
From 012318d091f531c0d24854252a022e04655256e6 Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Tue, 13 Jul 2021 14:31:34 +0200
Subject: [PATCH 12/17] added 'get' attribute to __partitioned__ to get a tile
 from a DNDarray

---
 heat/core/dndarray.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index b5b023cd7c..bfacd07b95 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -664,7 +664,6 @@ def create_partition_interface(self):
         -------
         dictionary containing the partition interface as shown above.
         """
-        # sp =
         lshape_map = self.create_lshape_map()
         start_idx_map = torch.zeros_like(lshape_map)

@@ -705,6 +704,12 @@ def create_partition_interface(self):
             "partitions": partitions,
             "locals": [tuple(lcls)],
         }
+
+        def _partition_getter(key):
+            return partition_dict["partitions"][key]["data"]
+
+        partition_dict["get"] = _partition_getter
+
         return partition_dict

     def __float__(self) -> DNDarray:

From 26203c0db52b0093a7d42f89090284c97454299d Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Tue, 13 Jul 2021 14:40:29 +0200
Subject: [PATCH 13/17] reduced level of abstraction for __partitioned__['get']

---
 heat/core/dndarray.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index bfacd07b95..44920f2145 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -703,13 +703,9 @@ def create_partition_interface(self):
             "partition_tiling": tuple(part_tiling),
             "partitions": partitions,
             "locals": [tuple(lcls)],
+            "get": lambda x: x,
         }

-        def _partition_getter(key):
-            return partition_dict["partitions"][key]["data"]
-
-        partition_dict["get"] = _partition_getter
-
         return partition_dict

     def __float__(self) -> DNDarray:

From 29d038504b0ef73c93cff2c4962744a28fce5aea Mon Sep 17 00:00:00 2001
From: Frank Schlimbach
Date: Thu, 26 Aug 2021 08:09:45 -0500
Subject: [PATCH 14/17] adding from_partitioned; aligning __partitioned__ with
 current spec

---
 heat/core/dndarray.py  | 12 +++---
 heat/core/factories.py | 85 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 89 insertions(+), 8 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index f17716a05c..112cdc3ddf 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -648,7 +648,7 @@ def create_partition_interface(self):
                     'start': (0, 0, 0),
                     'shape': (7, 3, 2),
                     'data': tensor([...], dtype=torch.int32),
-                    'location': 0,
+                    'location': [0],
                     'dtype': torch.int32,
                     'device': 'cpu'
                 },
@@ -656,7 +656,7 @@ def create_partition_interface(self):
                     'start': (7, 0, 0),
                     'shape': (7, 3, 2),
                     'data': None,
-                    'location': 1,
+                    'location': [1],
                     'dtype': torch.int32,
                     'device': 'cpu'
                 },
@@ -664,7 +664,7 @@ def create_partition_interface(self):
                     'start': (14, 0, 0),
                     'shape': (7, 3, 2),
                     'data': None,
-                    'location': 2,
+                    'location': [2],
                     'dtype': torch.int32,
                     'device': 'cpu'
                 },
@@ -672,12 +672,13 @@ def create_partition_interface(self):
                     'start': (21, 0, 0),
                     'shape': (6, 3, 2),
                     'data': None,
-                    'location': 3,
+                    'location': [3],
                     'dtype': torch.int32,
                     'device': 'cpu'
                 }
             },
             'locals': [(rank, 0, 0)],
+            'get': lambda x: x,
         }

         Returns
         -------
@@ -708,12 +709,11 @@ def create_partition_interface(self):
             dat = None if r != self.comm.rank else self.larray
         else:
             dat = self.larray
-
         partitions[tuple(base_key)] = {
             "start": tuple(start_idx_map[r].tolist()),
             "shape": tuple(lshape_map[r].tolist()),
             "data": dat,
-            "location": r,
+            "location": [r],
             "dtype": self.dtype.torch_type(),
             "device": self.device.torch_device,
         }
diff --git a/heat/core/factories.py b/heat/core/factories.py
index 1eb1ab92f3..f3be18feaf 100644
--- a/heat/core/factories.py
+++ b/heat/core/factories.py
@@ -34,6 +34,7 @@
     "ones_like",
     "zeros",
     "zeros_like",
+    "from_partitioned",
 ]

@@ -42,7 +43,7 @@ def arange(
     dtype: Optional[Type[datatype]] = None,
     split: Optional[int] = None,
     device: Optional[Union[str, Device]] = None,
-    comm: Optional[Communication] = None
+    comm: Optional[Communication] = None,
 ) -> DNDarray:
     """
     Return evenly spaced values within a given interval.
@@ -727,7 +728,7 @@ def __factory_like(
     device: Device,
     comm: Communication,
     order: str = "C",
-    **kwargs
+    **kwargs,
 ) -> DNDarray:
     """
     Abstracted '...-like' factory function for HeAT :class:`~heat.core.dndarray.DNDarray` initialization
@@ -1321,3 +1322,83 @@ def zeros_like(
     """
     # TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released.
     return __factory_like(a, dtype, split, zeros, device, comm, order=order)
+
+
+def from_partitioned(x, comm: Optional[Communication] = None) -> DNDarray:
+    """
+    Return a newly created DNDarray constructed from the '__partitioned__' attribute of the input object.
+    Memory of local partitions will be shared (zero-copy) as long as supported by data objects.
+    Currently supports numpy ndarrays and torch tensors as data objects.
+    Current limitations:
+        * Partitions must be ordered in the partition-grid by rank
+        * Only one split-axis
+        * Only one partition per rank
+        * Only SPMD-style __partitioned__
+
+    Parameters
+    ----------
+    x : object
+        Requires x.__partitioned__
+    comm: Communication, optional
+        Handle to the nodes holding distributed parts or copies of this array.
+
+    See also
+    --------
+    :func:`ht.core.DNDarray.create_partition_interface`.
+
+    Raises
+    ------
+    AttributeError
+        If not hasattr(x, "__partitioned__") or if underlying data has no dtype.
+    TypeError
+        If it finds an unsupported array type
+    RuntimeError
+        If other unsupported content is found.
+
+    Examples
+    --------
+    >>> import heat as ht
+    >>> a = ht.ones((44,55), split=0)
+    >>> b = ht.from_partitioned(a)
+    >>> assert (a==b).all()
+    >>> a[40] = 4711
+    >>> assert (a==b).all()
+    """
+    comm = sanitize_comm(comm)
+    parted = x.__partitioned__
+    if "locals" not in parted:
+        raise RuntimeError("Non-SPMD __partitioned__ not supported")
+
+    gshape = parted["shape"]
+    split = [x for x in range(len(gshape)) if x != 1]
+    if len(split) != 1:
+        raise RuntimeError("Only exactly one split-axis supported")
+    split = split[0]
+
+    lparts = parted["locals"]
+    if len(lparts) != 1:
+        raise RuntimeError("Only exactly one partition per rank supported (yet)")
+    parts = parted["partitions"]
+    lpart = parted["get"](parts[lparts[0]]["data"])
+    if isinstance(lpart, np.ndarray):
+        data = torch.from_numpy(lpart)
+    elif isinstance(lpart, torch.Tensor):
+        data = lpart
+    else:
+        raise TypeError(f"Only numpy arrays and torch tensors supported (not {type(lpart)})")
+    htype = types.canonical_heat_type(data.dtype)
+
+    expected = {
+        int(x["location"][0]): (
+            comm.chunk(gshape, split, x["location"][0])[1:],
+            (x["shape"], x["start"]),
+        )
+        for x in parts.values()
+    }
+    if any(i > 0 and expected[i][1][1][split] < expected[i - 1][1][1][split] for i in expected):
+        raise RuntimeError("__partitioned__ supported only if partitions are ordered by rank")
+    balanced = all(x[0][0] == x[1][0] for x in expected.values())
+
+    return DNDarray(
+        data, gshape, htype, split, devices.sanitize_device(None), sanitize_comm(comm), balanced
+    )

From a0c8ab475fff62e9e21b3e35d0b37c68fa7a3285 Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Thu, 23 Sep 2021 09:02:03 +0200
Subject: [PATCH 15/17] updating from_partitioned function

---
 heat/core/dndarray.py             |   2 +
 heat/core/factories.py            | 231 +++++++++++++++++++-----------
 heat/core/tests/test_factories.py |  61 ++++++++
 3 files changed, 212 insertions(+), 82 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 9a39e920a8..20a62ff5b3 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -721,6 +721,8 @@ def create_partition_interface(self):
             "get": lambda x: x,
         }

+        self.__partitions_dict__ = partition_dict
+
         return partition_dict

     def __float__(self) -> DNDarray:
diff --git a/heat/core/factories.py b/heat/core/factories.py
index f3be18feaf..2445e1026d 100644
--- a/heat/core/factories.py
+++ b/heat/core/factories.py
@@ -17,7 +17,6 @@
 from . import devices
 from . import types

-
 __all__ = [
     "arange",
     "array",
@@ -25,6 +24,8 @@
     "empty",
     "empty_like",
     "eye",
+    "from_partitioned",
+    "from_partition_dict",
     "full",
     "full_like",
     "linspace",
@@ -34,7 +35,6 @@
     "ones_like",
     "zeros",
     "zeros_like",
-    "from_partitioned",
 ]

@@ -790,6 +790,153 @@ def __factory_like(
     return factory(shape, dtype=dtype, split=split, device=device, comm=comm, order=order, **kwargs)


+def from_partitioned(x, comm: Optional[Communication] = None) -> DNDarray:
+    """
+    Return a newly created DNDarray constructed from the '__partitioned__' attribute of the input object.
+    Memory of local partitions will be shared (zero-copy) as long as supported by data objects.
+    Currently supports numpy ndarrays and torch tensors as data objects.
+    Current limitations:
+        * Partitions must be ordered in the partition-grid by rank
+        * Only one split-axis
+        * Only one partition per rank
+        * Only SPMD-style __partitioned__
+
+    Parameters
+    ----------
+    x : object
+        Requires x.__partitioned__
+    comm: Communication, optional
+        Handle to the nodes holding distributed parts or copies of this array.
+
+    See also
+    --------
+    :func:`ht.core.DNDarray.create_partition_interface`.
+
+    Raises
+    ------
+    AttributeError
+        If not hasattr(x, "__partitioned__") or if underlying data has no dtype.
+    TypeError
+        If it finds an unsupported array type
+    RuntimeError
+        If other unsupported content is found.
+
+    Examples
+    --------
+    >>> import heat as ht
+    >>> a = ht.ones((44,55), split=0)
+    >>> b = ht.from_partitioned(a)
+    >>> assert (a==b).all()
+    >>> a[40] = 4711
+    >>> assert (a==b).all()
+    """
+    comm = sanitize_comm(comm)
+    parted = x.__partitioned__
+    return __from_partition_dict_helper(parted, comm)
+
+
+def from_partition_dict(parted: dict, comm: Optional[Communication] = None) -> DNDarray:
+    """
+    Return a newly created DNDarray constructed from a '__partitioned__' dictionary.
+    Memory of local partitions will be shared (zero-copy) as long as supported by data objects.
+    Currently supports numpy ndarrays and torch tensors as data objects.
+    Current limitations:
+        * Partitions must be ordered in the partition-grid by rank
+        * Only one split-axis
+        * Only one partition per rank
+        * Only SPMD-style __partitioned__
+
+    Parameters
+    ----------
+    parted : dict
+        A partition dictionary used to create the new DNDarray
+    comm: Communication, optional
+        Handle to the nodes holding distributed parts or copies of this array.
+
+    See also
+    --------
+    :func:`ht.core.DNDarray.create_partition_interface`.
+
+    Raises
+    ------
+    AttributeError
+        If the underlying data has no dtype.
+    TypeError
+        If it finds an unsupported array type
+    RuntimeError
+        If other unsupported content is found.
+
+    Examples
+    --------
+    >>> import heat as ht
+    >>> a = ht.ones((44,55), split=0)
+    >>> b = ht.from_partition_dict(a.__partitioned__)
+    >>> assert (a==b).all()
+    >>> a[40] = 4711
+    >>> assert (a==b).all()
+    """
+    comm = sanitize_comm(comm)
+    return __from_partition_dict_helper(parted, comm)
+
+
+def __from_partition_dict_helper(parted: dict, comm: Communication):
+    # helper to create a DNDarray from a partition table (dictionary)
+    # the dictionary must be in the same form as the DNDarray.__partitioned__ property creates
+    if "locals" not in parted:
+        raise RuntimeError("Non-SPMD __partitioned__ not supported")
+
+    rank = comm.rank
+    size = comm.size
+    # TODO: check that the partitions are in the correct places
+    # TODO: set the local data to be the one in the partition table
+    partitions = parted["partitions"]
+    # for k in partitions.keys():
+    #     if ...:
+    #         ldata = ...
+
+    gshape = parted["shape"]
+    if len(parted['partitions']) == 1:  # this is the split=None case
+        split = None
+    else:
+        split = [x for x in range(len(gshape)) if x != 1]
+        if len(split) != 1:
+            raise RuntimeError("Only exactly one split-axis supported")
+        split = split[0]
+
+    lparts = parted["locals"]
+    if len(lparts) != 1:
+        raise RuntimeError("Only exactly one partition per rank supported (yet)")
+    parts = parted["partitions"]
+    lpart = parted["get"](parts[lparts[0]]["data"])
+    if isinstance(lpart, np.ndarray):
+        data = torch.from_numpy(lpart)
+    elif isinstance(lpart, torch.Tensor):
+        data = lpart
+    else:
+        raise TypeError(f"Only numpy arrays and torch tensors supported (not {type(lpart)})")
+    htype = types.canonical_heat_type(data.dtype)
+
+    expected = {
+        int(x["location"][0]): (
+            comm.chunk(gshape, split, x["location"][0])[1:],
+            (x["shape"], x["start"]),
+        )
+        for x in parts.values()
+    }
+    if split is not None and \
+       any(i > 0 and expected[i][1][1][split] < expected[i - 1][1][1][split] for i in expected
+    ):
+        raise RuntimeError("__partitioned__ supported only if partitions are ordered by rank")
+    balanced = all(x[0][0] == x[1][0] for x in expected.values())
+
+    ret = DNDarray(
+        data, gshape, htype, split, devices.sanitize_device(None), sanitize_comm(comm), balanced
+    )
+    ret.__partitions_dict__ = parted
+
+    return ret
+
+
 def full(
     shape: Union[int, Sequence[int]],
     fill_value: Union[int, float],
@@ -1322,83 +1469,3 @@ def zeros_like(
     """
     # TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released.
     return __factory_like(a, dtype, split, zeros, device, comm, order=order)
-
-
-def from_partitioned(x, comm: Optional[Communication] = None) -> DNDarray:
-    """
-    Return a newly created DNDarray constructed from the '__partitioned__' attribute of the input object.
-    Memory of local partitions will be shared (zero-copy) as long as supported by data objects.
-    Currently supports numpy ndarrays and torch tensors as data objects.
-    Current limitations:
-        * Partitions must be ordered in the partition-grid by rank
-        * Only one split-axis
-        * Only one partition per rank
-        * Only SPMD-style __partitioned__
-
-    Parameters
-    ----------
-    x : object
-        Requires x.__partitioned__
-    comm: Communication, optional
-        Handle to the nodes holding distributed parts or copies of this array.
-
-    See also
-    --------
-    :func:`ht.core.DNDarray.create_partition_interface`.
-
-    Raises
-    ------
-    AttributeError
-        If not hasattr(x, "__partitioned__") or if underlying data has no dtype.
-    TypeError
-        If it finds an unsupported array type
-    RuntimeError
-        If other unsupported content is found.
-
-    Examples
-    --------
-    >>> import heat as ht
-    >>> a = ht.ones((44,55), split=0)
-    >>> b = ht.from_partitioned(a)
-    >>> assert (a==b).all()
-    >>> a[40] = 4711
-    >>> assert (a==b).all()
-    """
-    comm = sanitize_comm(comm)
-    parted = x.__partitioned__
-    if "locals" not in parted:
-        raise RuntimeError("Non-SPMD __partitioned__ not supported")
-
-    gshape = parted["shape"]
-    split = [x for x in range(len(gshape)) if x != 1]
-    if len(split) != 1:
-        raise RuntimeError("Only exactly one split-axis supported")
-    split = split[0]
-
-    lparts = parted["locals"]
-    if len(lparts) != 1:
-        raise RuntimeError("Only exactly one partition per rank supported (yet)")
-    parts = parted["partitions"]
-    lpart = parted["get"](parts[lparts[0]]["data"])
-    if isinstance(lpart, np.ndarray):
-        data = torch.from_numpy(lpart)
-    elif isinstance(lpart, torch.Tensor):
-        data = lpart
-    else:
-        raise TypeError(f"Only numpy arrays and torch tensors supported (not {type(lpart)})")
-    htype = types.canonical_heat_type(data.dtype)
-
-    expected = {
-        int(x["location"][0]): (
-            comm.chunk(gshape, split, x["location"][0])[1:],
-            (x["shape"], x["start"]),
-        )
-        for x in parts.values()
-    }
-    if any(i > 0 and expected[i][1][1][split] < expected[i - 1][1][1][split] for i in expected):
-        raise RuntimeError("__partitioned__ supported only if partitions are ordered by rank")
-    balanced = all(x[0][0] == x[1][0] for x in expected.values())
-
-    return DNDarray(
-        data, gshape, htype, split, devices.sanitize_device(None), sanitize_comm(comm), balanced
-    )
diff --git a/heat/core/tests/test_factories.py b/heat/core/tests/test_factories.py
index c996c06232..8759652fa3 100644
--- a/heat/core/tests/test_factories.py
+++ b/heat/core/tests/test_factories.py
@@ -490,6 +490,67 @@ def get_offset(tensor_array):
         self.assertEqual(eye.shape, shape)
         self.assertEqual(eye.split, 1)

+    def test_from_partitioned(self):
+        a = ht.zeros((120, 120), split=0)
+        # b = ht.from_partitioned(a, comm=a.comm)
+        # self.assertTrue(ht.equal(a, b))
+        # # self.assertEqual(parted["shape"], (120, 120))
+        # # self.assertEqual(parted["partition_tiling"], (a.comm.size, 1))
+        # # self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))
+        #
+        # a.resplit_(None)
+        # b = ht.from_partitioned(a, comm=a.comm)
+        # self.assertTrue(ht.equal(a, b))
+
+        a.resplit_(1)
+        b = ht.from_partitioned(a, comm=a.comm)
+        self.assertTrue(ht.equal(a, b))
+
+        # del b.__partitioned__["shape"]
+        # with self.assertRaises(RuntimeError):
+        #     c = ht.from_partitioned(b)
+        # b.__partitions_dict__ = None
+        # _ = b.__partitioned__
+        #
+        # del b.__partitioned__["locals"]
+        # with self.assertRaises(RuntimeError):
+        #     c = ht.from_partitioned(b)
+        # b.__partitions_dict__ = None
+        # _ = b.__partitioned__
+        #
+        # del b.__partitioned__["locals"]
+        # with self.assertRaises(RuntimeError):
+        #     c = ht.from_partitioned(b)
+        # b.__partitions_dict__ = None
+        # _ = b.__partitioned__
+
+    def test_from_partition_dict(self):
+        a = ht.zeros((120, 120), split=0)
+        b = ht.from_partition_dict(a.__partitioned__, comm=a.comm)
+        self.assertTrue(ht.equal(a, b))
+
+        a.resplit_(None)
+        b = ht.from_partition_dict(a.__partitioned__, comm=a.comm)
+        self.assertTrue(ht.equal(a, b))
+
+        del b.__partitioned__["shape"]
+        with self.assertRaises(RuntimeError):
+            c = ht.from_partition_dict(b.__partitioned__)
+        b.__partitions_dict__ = None
+        _ = b.__partitioned__
+
+        del b.__partitioned__["locals"]
+        with self.assertRaises(RuntimeError):
+            c = ht.from_partition_dict(b.__partitioned__)
+        b.__partitions_dict__ = None
+        _ = b.__partitioned__
+
+        del b.__partitioned__["locals"]
+        with self.assertRaises(RuntimeError):
+            c = ht.from_partition_dict(b.__partitioned__)
+        b.__partitions_dict__ = None
+        _ = b.__partitioned__
+
     def test_full(self):
         # simple tensor
         data = ht.full((10, 2), 4)

From 7c70eae80b502e0909d6e86d0a4f5c9338aaed75 Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Thu, 23 Sep 2021 09:55:15 +0200
Subject: [PATCH 16/17] added nonzero split support to from partition
 dictionary, added tests, added factory function for building a dndarray from
 a partition dictionary

---
 heat/core/factories.py            | 51 +++++++++++++-------------
 heat/core/tests/test_factories.py | 59 ++++++++++++++++---------------
 2 files changed, 57 insertions(+), 53 deletions(-)

diff --git a/heat/core/factories.py b/heat/core/factories.py
index 2445e1026d..01b2f9014b 100644
--- a/heat/core/factories.py
+++ b/heat/core/factories.py
@@ -884,26 +884,18 @@ def __from_partition_dict_helper(parted: dict, comm: Communication):
     # the dictionary must be in the same form as the DNDarray.__partitioned__ property creates
     if "locals" not in parted:
         raise RuntimeError("Non-SPMD __partitioned__ not supported")
-
-    rank = comm.rank
-    size = comm.size
-    # TODO: check that the partitions are in the correct places
-    # TODO: set the local data to be the one in the partition table
-    partitions = parted["partitions"]
-    # for k in partitions.keys():
-    #     if ...:
-    #         ldata = ...
-
-    gshape = parted["shape"]
-    if len(parted['partitions']) == 1:  # this is the split=None case
-        split = None
-    else:
-        split = [x for x in range(len(gshape)) if x != 1]
-        if len(split) != 1:
-            raise RuntimeError("Only exactly one split-axis supported")
-        split = split[0]
-
-    lparts = parted["locals"]
+    try:
+        gshape = parted["shape"]
+    except KeyError:
+        raise RuntimeError(
+            "partition dictionary must have a 'shape' entry, see DNDarray.create_partition_interface for more details"
+        )
+    try:
+        lparts = parted["locals"]
+    except KeyError:
+        raise RuntimeError(
+            "partition dictionary must have a 'locals' entry, see DNDarray.create_partition_interface for more details"
+        )
     if len(lparts) != 1:
         raise RuntimeError("Only exactly one partition per rank supported (yet)")
     parts = parted["partitions"]
@@ -916,6 +908,21 @@ def __from_partition_dict_helper(parted: dict, comm: Communication):
         raise TypeError(f"Only numpy arrays and torch tensors supported (not {type(lpart)})")
     htype = types.canonical_heat_type(data.dtype)

+    # get split axis
+    gshape_list = list(gshape)
+    lshape_list = list(data.shape)
+    shape_diff = torch.tensor(
+        [g - l for g, l in zip(gshape_list, lshape_list)]
+    )  # don't care about device
+    nz = torch.nonzero(shape_diff)
+
+    if nz.numel() > 1:
+        raise RuntimeError("only one split axis allowed")
+    elif nz.numel() == 1:
+        split = nz[0].item()
+    else:
+        split = None
+
     expected = {
         int(x["location"][0]): (
             comm.chunk(gshape, split, x["location"][0])[1:],
             (x["shape"], x["start"]),
         )
         for x in parts.values()
     }
-    if split is not None and \
-       any(i > 0 and expected[i][1][1][split] < expected[i - 1][1][1][split] for i in expected
-    ):
-        raise RuntimeError("__partitioned__ supported only if partitions are ordered by rank")
     balanced = all(x[0][0] == x[1][0] for x in expected.values())

diff --git a/heat/core/tests/test_factories.py b/heat/core/tests/test_factories.py
index 8759652fa3..275f9f6722 100644
--- a/heat/core/tests/test_factories.py
+++ b/heat/core/tests/test_factories.py
@@ -492,62 +492,63 @@ def test_from_partitioned(self):
         a = ht.zeros((120, 120), split=0)
-        # b = ht.from_partitioned(a, comm=a.comm)
-        # self.assertTrue(ht.equal(a, b))
-        # # self.assertEqual(parted["shape"], (120, 120))
-        # # self.assertEqual(parted["partition_tiling"], (a.comm.size, 1))
-        # # self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))
-        #
-        # a.resplit_(None)
-        # b = ht.from_partitioned(a, comm=a.comm)
-        # self.assertTrue(ht.equal(a, b))
+        b = ht.from_partitioned(a, comm=a.comm)
+        a[2, :] = 128
+        self.assertTrue(ht.equal(a, b))
+
+        a.resplit_(None)
+        b = ht.from_partitioned(a, comm=a.comm)
+        self.assertTrue(ht.equal(a, b))

         a.resplit_(1)
         b = ht.from_partitioned(a, comm=a.comm)
+        b[50] = 94
         self.assertTrue(ht.equal(a, b))

-        # del b.__partitioned__["shape"]
-        # with self.assertRaises(RuntimeError):
-        #     c = ht.from_partitioned(b)
-        # b.__partitions_dict__ = None
-        # _ = b.__partitioned__
-        #
-        # del b.__partitioned__["locals"]
-        # with self.assertRaises(RuntimeError):
-        #     c = ht.from_partitioned(b)
-        # b.__partitions_dict__ = None
-        # _ = b.__partitioned__
-        #
-        # del b.__partitioned__["locals"]
-        # with self.assertRaises(RuntimeError):
-        #     c = ht.from_partitioned(b)
-        # b.__partitions_dict__ = None
-        # _ = b.__partitioned__
+        del b.__partitioned__["shape"]
+        with self.assertRaises(RuntimeError):
+            _ = ht.from_partitioned(b)
+        b.__partitions_dict__ = None
+        _ = b.__partitioned__
+
+        del b.__partitioned__["locals"]
+        with self.assertRaises(RuntimeError):
+            _ = ht.from_partitioned(b)
+        b.__partitions_dict__ = None
+        _ = b.__partitioned__
+
+        del b.__partitioned__["locals"]
+        with self.assertRaises(RuntimeError):
+            _ = ht.from_partitioned(b)
+        b.__partitions_dict__ = None
+        _ = b.__partitioned__

     def test_from_partition_dict(self):
         a = ht.zeros((120, 120), split=0)
         b = ht.from_partition_dict(a.__partitioned__, comm=a.comm)
+        a[0, 0] = 100
         self.assertTrue(ht.equal(a, b))

         a.resplit_(None)
+        a[0, 0] = 50
         b = ht.from_partition_dict(a.__partitioned__, comm=a.comm)
         self.assertTrue(ht.equal(a, b))

         del b.__partitioned__["shape"]
         with self.assertRaises(RuntimeError):
-            c = ht.from_partition_dict(b.__partitioned__)
+            _ = ht.from_partition_dict(b.__partitioned__)
         b.__partitions_dict__ = None
         _ = b.__partitioned__

         del b.__partitioned__["locals"]
         with self.assertRaises(RuntimeError):
-            c = ht.from_partition_dict(b.__partitioned__)
+            _ = ht.from_partition_dict(b.__partitioned__)
         b.__partitions_dict__ = None
         _ = b.__partitioned__

         del b.__partitioned__["locals"]
         with self.assertRaises(RuntimeError):
-            c = ht.from_partition_dict(b.__partitioned__)
+            _ = ht.from_partition_dict(b.__partitioned__)
         b.__partitions_dict__ = None
         _ = b.__partitioned__

From fd7ec8369da0b8c44d49321eace1fc603cbb482d Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Thu, 9 Feb 2023 05:31:20 +0100
Subject: [PATCH 17/17] Ensure `__partitions_dict__` is None when virtually
 resplitting to None on 1 process

---
 heat/core/dndarray.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 3def85f730..f06ba02bd7 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -1387,6 +1387,8 @@ def resplit_(self, axis: int = None):
         # early out for unchanged content
         if self.comm.size == 1:
             self.__split = axis
+            if axis is None:
+                self.__partitions_dict__ = None
         if axis == self.split:
             return self
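Taken together, the series makes the partition table round-trip: __partitioned__ exports it, and from_partitioned / from_partition_dict rebuild a DNDarray from it without copying the local tiles. A closing sketch of the end state (illustration only, not part of any patch; it follows the docstring examples above and works for any process count):

    import heat as ht

    a = ht.ones((44, 55), split=0)

    b = ht.from_partitioned(a)                     # reads a.__partitioned__
    c = ht.from_partition_dict(a.__partitioned__)  # same table, passed explicitly

    # the local torch tensors are shared (zero-copy), so an in-place write
    # through one handle is visible through the others
    a[40] = 4711
    assert (a == b).all() and (a == c).all()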