Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit b7cdaa0

Browse files
authoredJun 17, 2022
Allow assigning Coercible values + NamedNode internal class (#115)
* test assigning int * allow assigning coercible values * refactor name-related methods to intermediate class * refactor tests to match * fix now-exposed bug with naming * moved test showing lack of name permanence * whatsnew
1 parent 4cb72fb commit b7cdaa0

File tree

5 files changed

+164
-126
lines changed

5 files changed

+164
-126
lines changed
 

‎datatree/datatree.py

+10-21
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
MappedDataWithCoords,
4242
)
4343
from .render import RenderTree
44-
from .treenode import NodePath, Tree, TreeNode
44+
from .treenode import NamedNode, NodePath, Tree
4545

4646
if TYPE_CHECKING:
4747
from xarray.core.merge import CoercibleValue
@@ -228,7 +228,7 @@ def _replace(
228228

229229

230230
class DataTree(
231-
TreeNode,
231+
NamedNode,
232232
MappedDatasetMethodsMixin,
233233
MappedDataWithCoords,
234234
DataTreeArithmeticMixin,
@@ -343,20 +343,6 @@ def __init__(
343343
)
344344
self._close = ds._close
345345

346-
@property
347-
def name(self) -> str | None:
348-
"""The name of this node."""
349-
return self._name
350-
351-
@name.setter
352-
def name(self, name: str | None) -> None:
353-
if name is not None:
354-
if not isinstance(name, str):
355-
raise TypeError("node name must be a string or None")
356-
if "/" in name:
357-
raise ValueError("node names cannot contain forward slashes")
358-
self._name = name
359-
360346
@property
361347
def parent(self: DataTree) -> DataTree | None:
362348
"""Parent of this node."""
@@ -699,14 +685,17 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
699685
if isinstance(val, DataTree):
700686
val.name = key
701687
val.parent = self
702-
elif isinstance(val, (DataArray, Variable)):
703-
# TODO this should also accomodate other types that can be coerced into Variables
704-
self.update({key: val})
705688
else:
706-
raise TypeError(f"Type {type(val)} cannot be assigned to a DataTree")
689+
if not isinstance(val, (DataArray, Variable)):
690+
# accommodate other types that can be coerced into Variables
691+
val = DataArray(val)
692+
693+
self.update({key: val})
707694

708695
def __setitem__(
709-
self, key: str, value: DataTree | Dataset | DataArray | Variable
696+
self,
697+
key: str,
698+
value: Any,
710699
) -> None:
711700
"""
712701
Add either a child node or an array to the tree, at any position.

‎datatree/tests/test_datatree.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -249,13 +249,6 @@ def test_setitem_unnamed_child_node_becomes_named(self):
249249
john2["sonny"] = DataTree()
250250
assert john2["sonny"].name == "sonny"
251251

252-
@pytest.mark.xfail(reason="bug with name overwriting")
253-
def test_setitem_child_node_keeps_name(self):
254-
john = DataTree(name="john")
255-
r2d2 = DataTree(name="R2D2")
256-
john["Mary"] = r2d2
257-
assert r2d2.name == "R2D2"
258-
259252
def test_setitem_new_grandchild_node(self):
260253
john = DataTree(name="john")
261254
DataTree(name="mary", parent=john)
@@ -314,6 +307,17 @@ def test_setitem_unnamed_dataarray(self):
314307
folder1["results"] = data
315308
xrt.assert_equal(folder1["results"], data)
316309

310+
def test_setitem_variable(self):
311+
var = xr.Variable(data=[0, 50], dims="x")
312+
folder1 = DataTree(name="folder1")
313+
folder1["results"] = var
314+
xrt.assert_equal(folder1["results"], xr.DataArray(var))
315+
316+
def test_setitem_coerce_to_dataarray(self):
317+
folder1 = DataTree(name="folder1")
318+
folder1["results"] = 0
319+
xrt.assert_equal(folder1["results"], xr.DataArray(0))
320+
317321
def test_setitem_add_new_variable_to_empty_node(self):
318322
results = DataTree(name="results")
319323
results["pressure"] = xr.DataArray(data=[2, 3])

‎datatree/tests/test_treenode.py

+72-59
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from datatree.iterators import LevelOrderIter, PreOrderIter
4-
from datatree.treenode import TreeError, TreeNode
4+
from datatree.treenode import NamedNode, TreeError, TreeNode
55

66

77
class TestFamilyTree:
@@ -143,43 +143,6 @@ def test_get_from_root(self):
143143
assert sue._get_item("/Mary") is mary
144144

145145

146-
class TestPaths:
147-
def test_path_property(self):
148-
sue = TreeNode()
149-
mary = TreeNode(children={"Sue": sue})
150-
john = TreeNode(children={"Mary": mary}) # noqa
151-
assert sue.path == "/Mary/Sue"
152-
assert john.path == "/"
153-
154-
def test_path_roundtrip(self):
155-
sue = TreeNode()
156-
mary = TreeNode(children={"Sue": sue})
157-
john = TreeNode(children={"Mary": mary}) # noqa
158-
assert john._get_item(sue.path) == sue
159-
160-
def test_same_tree(self):
161-
mary = TreeNode()
162-
kate = TreeNode()
163-
john = TreeNode(children={"Mary": mary, "Kate": kate}) # noqa
164-
assert mary.same_tree(kate)
165-
166-
def test_relative_paths(self):
167-
sue = TreeNode()
168-
mary = TreeNode(children={"Sue": sue})
169-
annie = TreeNode()
170-
john = TreeNode(children={"Mary": mary, "Annie": annie})
171-
172-
assert sue.relative_to(john) == "Mary/Sue"
173-
assert john.relative_to(sue) == "../.."
174-
assert annie.relative_to(sue) == "../../Annie"
175-
assert sue.relative_to(annie) == "../Mary/Sue"
176-
assert sue.relative_to(sue) == "."
177-
178-
evil_kate = TreeNode()
179-
with pytest.raises(ValueError, match="nodes do not lie within the same tree"):
180-
sue.relative_to(evil_kate)
181-
182-
183146
class TestSetNodes:
184147
def test_set_child_node(self):
185148
john = TreeNode()
@@ -261,16 +224,66 @@ def test_del_child(self):
261224
del john["Mary"]
262225

263226

227+
class TestNames:
228+
def test_child_gets_named_on_attach(self):
229+
sue = NamedNode()
230+
mary = NamedNode(children={"Sue": sue}) # noqa
231+
assert sue.name == "Sue"
232+
233+
@pytest.mark.xfail(reason="requires refactoring to retain name")
234+
def test_grafted_subtree_retains_name(self):
235+
subtree = NamedNode("original")
236+
root = NamedNode(children={"new_name": subtree}) # noqa
237+
assert subtree.name == "original"
238+
239+
240+
class TestPaths:
241+
def test_path_property(self):
242+
sue = NamedNode()
243+
mary = NamedNode(children={"Sue": sue})
244+
john = NamedNode(children={"Mary": mary}) # noqa
245+
assert sue.path == "/Mary/Sue"
246+
assert john.path == "/"
247+
248+
def test_path_roundtrip(self):
249+
sue = NamedNode()
250+
mary = NamedNode(children={"Sue": sue})
251+
john = NamedNode(children={"Mary": mary}) # noqa
252+
assert john._get_item(sue.path) == sue
253+
254+
def test_same_tree(self):
255+
mary = NamedNode()
256+
kate = NamedNode()
257+
john = NamedNode(children={"Mary": mary, "Kate": kate}) # noqa
258+
assert mary.same_tree(kate)
259+
260+
def test_relative_paths(self):
261+
sue = NamedNode()
262+
mary = NamedNode(children={"Sue": sue})
263+
annie = NamedNode()
264+
john = NamedNode(children={"Mary": mary, "Annie": annie})
265+
266+
assert sue.relative_to(john) == "Mary/Sue"
267+
assert john.relative_to(sue) == "../.."
268+
assert annie.relative_to(sue) == "../../Annie"
269+
assert sue.relative_to(annie) == "../Mary/Sue"
270+
assert sue.relative_to(sue) == "."
271+
272+
evil_kate = NamedNode()
273+
with pytest.raises(ValueError, match="nodes do not lie within the same tree"):
274+
sue.relative_to(evil_kate)
275+
276+
264277
def create_test_tree():
265-
f = TreeNode()
266-
b = TreeNode()
267-
a = TreeNode()
268-
d = TreeNode()
269-
c = TreeNode()
270-
e = TreeNode()
271-
g = TreeNode()
272-
i = TreeNode()
273-
h = TreeNode()
278+
f = NamedNode()
279+
b = NamedNode()
280+
a = NamedNode()
281+
d = NamedNode()
282+
c = NamedNode()
283+
e = NamedNode()
284+
g = NamedNode()
285+
i = NamedNode()
286+
h = NamedNode()
274287

275288
f.children = {"b": b, "g": g}
276289
b.children = {"a": a, "d": d}
@@ -286,7 +299,7 @@ def test_preorderiter(self):
286299
tree = create_test_tree()
287300
result = [node.name for node in PreOrderIter(tree)]
288301
expected = [
289-
None, # root TreeNode is unnamed
302+
None, # root Node is unnamed
290303
"b",
291304
"a",
292305
"d",
@@ -302,7 +315,7 @@ def test_levelorderiter(self):
302315
tree = create_test_tree()
303316
result = [node.name for node in LevelOrderIter(tree)]
304317
expected = [
305-
None, # root TreeNode is unnamed
318+
None, # root Node is unnamed
306319
"b",
307320
"g",
308321
"a",
@@ -317,19 +330,19 @@ def test_levelorderiter(self):
317330

318331
class TestRenderTree:
319332
def test_render_nodetree(self):
320-
sam = TreeNode()
321-
ben = TreeNode()
322-
mary = TreeNode(children={"Sam": sam, "Ben": ben})
323-
kate = TreeNode()
324-
john = TreeNode(children={"Mary": mary, "Kate": kate})
333+
sam = NamedNode()
334+
ben = NamedNode()
335+
mary = NamedNode(children={"Sam": sam, "Ben": ben})
336+
kate = NamedNode()
337+
john = NamedNode(children={"Mary": mary, "Kate": kate})
325338

326339
printout = john.__str__()
327340
expected_nodes = [
328-
"TreeNode()",
329-
"TreeNode('Mary')",
330-
"TreeNode('Sam')",
331-
"TreeNode('Ben')",
332-
"TreeNode('Kate')",
341+
"NamedNode()",
342+
"NamedNode('Mary')",
343+
"NamedNode('Sam')",
344+
"NamedNode('Ben')",
345+
"NamedNode('Kate')",
333346
]
334347
for expected_node, printed_node in zip(expected_nodes, printout.splitlines()):
335348
assert expected_node in printed_node

‎datatree/treenode.py

+65-39
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _check_loop(self, new_parent: Tree | None) -> None:
113113

114114
if self._is_descendant_of(new_parent):
115115
raise TreeError(
116-
f"Cannot set parent, as node {new_parent.name} is already a descendant of this node."
116+
"Cannot set parent, as intended parent is already a descendant of this node."
117117
)
118118

119119
def _is_descendant_of(self, node: Tree) -> bool:
@@ -461,18 +461,72 @@ def update(self: Tree, other: Mapping[str, Tree]) -> None:
461461
new_children = {**self.children, **other}
462462
self.children = new_children
463463

464+
def same_tree(self, other: Tree) -> bool:
465+
"""True if other node is in the same tree as this node."""
466+
return self.root is other.root
467+
468+
def find_common_ancestor(self, other: Tree) -> Tree:
469+
"""
470+
Find the first common ancestor of two nodes in the same tree.
471+
472+
Raise ValueError if they are not in the same tree.
473+
"""
474+
common_ancestor = None
475+
for node in other.iter_lineage():
476+
if node in self.ancestors:
477+
common_ancestor = node
478+
break
479+
480+
if not common_ancestor:
481+
raise ValueError(
482+
"Cannot find relative path because nodes do not lie within the same tree"
483+
)
484+
485+
return common_ancestor
486+
487+
def _path_to_ancestor(self, ancestor: Tree) -> NodePath:
488+
generation_gap = list(self.lineage).index(ancestor)
489+
path_upwards = "../" * generation_gap if generation_gap > 0 else "/"
490+
return NodePath(path_upwards)
491+
492+
493+
class NamedNode(TreeNode, Generic[Tree]):
494+
"""
495+
A TreeNode which knows its own name.
496+
497+
Implements path-like relationships to other nodes in its tree.
498+
"""
499+
500+
_name: Optional[str]
501+
_parent: Optional[Tree]
502+
_children: OrderedDict[str, Tree]
503+
504+
def __init__(self, name=None, children=None):
505+
super().__init__(children=children)
506+
self._name = None
507+
self.name = name
508+
464509
@property
465510
def name(self) -> str | None:
466-
"""If node has a parent, this is the key under which it is stored in `parent.children`."""
467-
if self.parent:
468-
return next(
469-
name for name, child in self.parent.children.items() if child is self
470-
)
471-
else:
472-
return None
511+
"""The name of this node."""
512+
return self._name
513+
514+
@name.setter
515+
def name(self, name: str | None) -> None:
516+
if name is not None:
517+
if not isinstance(name, str):
518+
raise TypeError("node name must be a string or None")
519+
if "/" in name:
520+
raise ValueError("node names cannot contain forward slashes")
521+
self._name = name
473522

474523
def __str__(self) -> str:
475-
return f"TreeNode({self.name})" if self.name else "TreeNode()"
524+
return f"NamedNode({self.name})" if self.name else "NamedNode()"
525+
526+
def _post_attach(self: NamedNode, parent: NamedNode) -> None:
527+
"""Ensures child has name attribute corresponding to key under which it has been stored."""
528+
key = next(k for k, v in parent.children.items() if v is self)
529+
self.name = key
476530

477531
@property
478532
def path(self) -> str:
@@ -483,9 +537,9 @@ def path(self) -> str:
483537
root, *ancestors = self.ancestors
484538
# don't include name of root because (a) root might not have a name & (b) we want path relative to root.
485539
names = [node.name for node in ancestors]
486-
return "/" + "/".join(names) # type: ignore
540+
return "/" + "/".join(names)
487541

488-
def relative_to(self, other: Tree) -> str:
542+
def relative_to(self: NamedNode, other: NamedNode) -> str:
489543
"""
490544
Compute the relative path from this node to node `other`.
491545
@@ -505,31 +559,3 @@ def relative_to(self, other: Tree) -> str:
505559
return str(
506560
path_to_common_ancestor / this_path.relative_to(common_ancestor.path)
507561
)
508-
509-
def same_tree(self, other: Tree) -> bool:
510-
"""True if other node is in the same tree as this node."""
511-
return self.root is other.root
512-
513-
def find_common_ancestor(self, other: Tree) -> Tree:
514-
"""
515-
Find the first common ancestor of two nodes in the same tree.
516-
517-
Raise ValueError if they are not in the same tree.
518-
"""
519-
common_ancestor = None
520-
for node in other.iter_lineage():
521-
if node in self.ancestors:
522-
common_ancestor = node
523-
break
524-
525-
if not common_ancestor:
526-
raise ValueError(
527-
"Cannot find relative path because nodes do not lie within the same tree"
528-
)
529-
530-
return common_ancestor
531-
532-
def _path_to_ancestor(self, ancestor: Tree) -> NodePath:
533-
generation_gap = list(self.lineage).index(ancestor)
534-
path_upwards = "../" * generation_gap if generation_gap > 0 else "/"
535-
return NodePath(path_upwards)

‎docs/source/whats-new.rst

+6
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ Bug fixes
4242
- Modifying the contents of a ``DataTree`` object via the ``DataTree.ds`` attribute is now forbidden, which prevents
4343
any possibility of the contents of a ``DataTree`` object and its ``.ds`` attribute diverging. (:issue:`38`, :pull:`99`)
4444
By `Tom Nicholas <https://github.com/TomNicholas>`_.
45+
- Fixed a bug so that names of children now always match keys under which parents store them (:pull:`99`).
46+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
4547

4648
Documentation
4749
~~~~~~~~~~~~~
@@ -56,10 +58,14 @@ Internal Changes
5658
This approach means that the ``DataTree`` class now effectively copies and extends the internal structure of
5759
``xarray.Dataset``. (:pull:`41`)
5860
By `Tom Nicholas <https://github.com/TomNicholas>`_.
61+
- Refactored to use intermediate ``NamedNode`` class, separating implementation of methods requiring a ``name``
62+
attribute from those not requiring it.
63+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
5964
- Made ``testing.test_datatree.create_test_datatree`` into a pytest fixture (:pull:`107`).
6065
By `Benjamin Woods <https://github.com/benjaminwoods>`_.
6166

6267

68+
6369
.. _whats-new.v0.0.6:
6470

6571
v0.0.6 (06/03/2022)

0 commit comments

Comments
 (0)
This repository has been archived.