Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit d76f050

Browse files
authored
Add assign method (#181)
* WIP assign method * remove rogue prints * remove assign from mapped methods * moved assign in docs to reflect it only acting on single nodes now * whatsnew
1 parent 4eb70a6 commit d76f050

File tree

5 files changed

+70
-3
lines changed

5 files changed

+70
-3
lines changed

Diff for: datatree/datatree.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from xarray.core.indexes import Index, Indexes
3131
from xarray.core.merge import dataset_update_method
3232
from xarray.core.options import OPTIONS as XR_OPTS
33-
from xarray.core.utils import Default, Frozen, _default
33+
from xarray.core.utils import Default, Frozen, _default, either_dict_or_kwargs
3434
from xarray.core.variable import Variable, calculate_dimensions
3535

3636
from . import formatting, formatting_html
@@ -821,6 +821,48 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None:
821821
inplace=True, children=merged_children, **vars_merge_result._asdict()
822822
)
823823

824+
def assign(
825+
self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any
826+
) -> DataTree:
827+
"""
828+
Assign new data variables or child nodes to a DataTree, returning a new object
829+
with all the original items in addition to the new ones.
830+
831+
Parameters
832+
----------
833+
items : mapping of hashable to Any
834+
Mapping from variable or child node names to the new values. If the new values
835+
are callable, they are computed on the Dataset and assigned to new
836+
data variables. If the values are not callable, (e.g. a DataTree, DataArray,
837+
scalar, or array), they are simply assigned.
838+
**items_kwargs
839+
The keyword arguments form of ``variables``.
840+
One of variables or variables_kwargs must be provided.
841+
842+
Returns
843+
-------
844+
dt : DataTree
845+
A new DataTree with the new variables or children in addition to all the
846+
existing items.
847+
848+
Notes
849+
-----
850+
Since ``kwargs`` is a dictionary, the order of your arguments may not
851+
be preserved, and so the order of the new variables is not well-defined.
852+
Assigning multiple items within the same ``assign`` is
853+
possible, but you cannot reference other variables created within the
854+
same ``assign`` call.
855+
856+
See Also
857+
--------
858+
xarray.Dataset.assign
859+
pandas.DataFrame.assign
860+
"""
861+
items = either_dict_or_kwargs(items, items_kwargs, "assign")
862+
dt = self.copy()
863+
dt.update(items)
864+
return dt
865+
824866
def drop_nodes(
825867
self: DataTree, names: str | Iterable[str], *, errors: ErrorOptions = "raise"
826868
) -> DataTree:

Diff for: datatree/ops.py

-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
"combine_first",
6969
"reduce",
7070
"map",
71-
"assign",
7271
"diff",
7372
"shift",
7473
"roll",

Diff for: datatree/tests/test_datatree.py

+23
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,17 @@ def test_getitem_dict_like_selection_access_to_dataset(self):
206206

207207

208208
class TestUpdate:
209+
def test_update(self):
210+
dt = DataTree()
211+
dt.update({"foo": xr.DataArray(0), "a": DataTree()})
212+
expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None})
213+
print(dt)
214+
print(dt.children)
215+
print(dt._children)
216+
print(dt["a"])
217+
print(expected)
218+
dtt.assert_equal(dt, expected)
219+
209220
def test_update_new_named_dataarray(self):
210221
da = xr.DataArray(name="temp", data=[0, 50])
211222
folder1 = DataTree(name="folder1")
@@ -542,6 +553,18 @@ def test_drop_nodes(self):
542553
childless = dropped.drop_nodes(names=["Mary", "Ashley"], errors="ignore")
543554
assert childless.children == {}
544555

556+
def test_assign(self):
557+
dt = DataTree()
558+
expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None})
559+
560+
# kwargs form
561+
result = dt.assign(foo=xr.DataArray(0), a=DataTree())
562+
dtt.assert_equal(result, expected)
563+
564+
# dict form
565+
result = dt.assign({"foo": xr.DataArray(0), "a": DataTree()})
566+
dtt.assert_equal(result, expected)
567+
545568

546569
class TestPipe:
547570
def test_noop(self, create_test_datatree):

Diff for: docs/source/api.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ Manipulate the contents of all nodes in a tree simultaneously.
109109
:toctree: generated/
110110

111111
DataTree.copy
112-
DataTree.assign
112+
113113
DataTree.assign_coords
114114
DataTree.merge
115115
DataTree.rename
@@ -130,6 +130,7 @@ Manipulate the contents of a single DataTree node.
130130
.. autosummary::
131131
:toctree: generated/
132132

133+
DataTree.assign
133134
DataTree.drop_nodes
134135

135136
Comparisons

Diff for: docs/source/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ Breaking changes
3939
By `Tom Nicholas <https://github.com/TomNicholas>`_.
4040
- Grafting a subtree onto another tree now leaves name of original subtree object unchanged (:issue:`116`, :pull:`172`, :pull:`178`).
4141
By `Tom Nicholas <https://github.com/TomNicholas>`_.
42+
- Changed the :py:meth:`DataTree.assign` method to just work on the local node (:pull:`181`).
43+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
4244

4345
Deprecations
4446
~~~~~~~~~~~~

0 commit comments

Comments
 (0)