Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shorten text repr for DataTree #10139

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions xarray/core/datatree_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from collections import namedtuple
from collections.abc import Iterable, Iterator
from math import ceil
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(
style=None,
childiter: type = list,
maxlevel: int | None = None,
maxchildren: int | None = None,
):
"""
Render tree starting at `node`.
Expand All @@ -88,6 +90,7 @@ def __init__(
Iterables that change the order of children cannot be used
(e.g., `reversed`).
maxlevel: Limit rendering to this depth.
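maxchildren: Limit number of children rendered at each node. When a node has more children than this, only the first ``ceil(maxchildren / 2)`` and the last ``floor(maxchildren / 2)`` are rendered, and a ``...`` row marks the omitted ones.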
:any:`RenderDataTree` is an iterator, returning a tuple with 3 items:
`pre`
tree prefix.
Expand Down Expand Up @@ -160,6 +163,16 @@ def __init__(
root
├── sub0
└── sub1

# `maxchildren` limits the number of children per node

>>> print(RenderDataTree(root, maxchildren=1).by_attr("name"))
root
├── sub0
│ ├── sub0B
│ ...
...

"""
if style is None:
style = ContStyle()
Expand All @@ -169,20 +182,40 @@ def __init__(
self.style = style
self.childiter = childiter
self.maxlevel = maxlevel
self.maxchildren = maxchildren

def __iter__(self) -> Iterator[Row]:
return self.__next(self.node, tuple())

def __next(
self, node: DataTree, continues: tuple[bool, ...], level: int = 0
self,
node: DataTree,
continues: tuple[bool, ...],
level: int = 0,
) -> Iterator[Row]:
yield RenderDataTree.__item(node, continues, self.style)
children = node.children.values()
level += 1
if children and (self.maxlevel is None or level < self.maxlevel):
nchildren = len(children)
children = self.childiter(children)
for child, is_last in _is_last(children):
yield from self.__next(child, continues + (not is_last,), level=level)
for i, (child, is_last) in enumerate(_is_last(children)):
if (
self.maxchildren is None
or i < ceil(self.maxchildren / 2)
or i >= ceil(nchildren - self.maxchildren / 2)
):
yield from self.__next(
child,
continues + (not is_last,),
level=level,
)
if (
self.maxchildren is not None
and nchildren > self.maxchildren
and i == ceil(self.maxchildren / 2)
):
yield RenderDataTree.__item("...", continues, self.style)

@staticmethod
def __item(
Expand Down Expand Up @@ -244,6 +277,9 @@ def by_attr(self, attrname: str = "name") -> str:

def get() -> Iterator[str]:
for pre, fill, node in self:
if isinstance(node, str):
yield f"{fill}{node}"
continue
attr = (
attrname(node)
if callable(attrname)
Expand Down
9 changes: 8 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,14 +1137,21 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:

def datatree_repr(dt: DataTree) -> str:
"""A printable representation of the structure of this entire tree."""
renderer = RenderDataTree(dt)
max_children = OPTIONS["display_max_children"]

renderer = RenderDataTree(dt, maxchildren=max_children)

name_info = "" if dt.name is None else f" {dt.name!r}"
header = f"<xarray.DataTree{name_info}>"

lines = [header]
show_inherited = True

for pre, fill, node in renderer:
if isinstance(node, str):
lines.append(f"{fill}{node}")
continue

node_repr = _datatree_node_repr(node, show_inherited=show_inherited)
show_inherited = False # only show inherited coords on the root

Expand Down
53 changes: 35 additions & 18 deletions xarray/core/formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import lru_cache, partial
from html import escape
from importlib.resources import files
from math import ceil
from typing import TYPE_CHECKING

from xarray.core.formatting import (
Expand All @@ -14,7 +15,7 @@
inline_variable_array_repr,
short_data_repr,
)
from xarray.core.options import _get_boolean_with_default
from xarray.core.options import OPTIONS, _get_boolean_with_default

STATIC_FILES = (
("xarray.static.html", "icons-svg-inline.html"),
Expand Down Expand Up @@ -192,16 +193,26 @@ def collapsible_section(


def _mapping_section(
mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True
mapping,
name,
details_func,
max_items_collapse,
expand_option_name,
enabled=True,
max_option_name: str | None = None,
) -> str:
n_items = len(mapping)
expanded = _get_boolean_with_default(
expand_option_name, n_items < max_items_collapse
)
collapsed = not expanded
max_items = OPTIONS.get(max_option_name)
truncated = max_items is not None and n_items > max_items
inline_details = f"({max_items}/{n_items})" if truncated else ""

return collapsible_section(
name,
inline_details=inline_details,
details=details_func(mapping),
n_items=n_items,
enabled=enabled,
Expand Down Expand Up @@ -348,26 +359,31 @@ def dataset_repr(ds) -> str:


def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
N_CHILDREN = len(children) - 1

# Get result from datatree_node_repr and wrap it
lines_callback = lambda n, c, end: _wrap_datatree_repr(
datatree_node_repr(n, c), end=end
)

children_html = "".join(
(
lines_callback(n, c, end=False) # Long lines
if i < N_CHILDREN
else lines_callback(n, c, end=True)
) # Short lines
for i, (n, c) in enumerate(children.items())
)
MAX_CHILDREN = OPTIONS["display_max_children"]
n_children = len(children)

children_html = []
for i, (n, c) in enumerate(children.items()):
if (
MAX_CHILDREN is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't MAX_CHILDREN always an int? It's defined like that in OPTIONS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was originally thinking that it should be nullable, but I didn't see any other nullable options, so I was worried that it might be an antipattern. The case for null is to allow arbitrarily large trees, which obviously has the performance implications that triggered this work in the first place. All that is to say: yes, it is always an int. I'll change the code to make that explicit.

or i < ceil(MAX_CHILDREN / 2)
or i >= ceil(n_children - MAX_CHILDREN / 2)
):
is_last = i == (n_children - 1)
children_html.append(
_wrap_datatree_repr(datatree_node_repr(n, c), end=is_last)
)
elif (
MAX_CHILDREN is not None
and n_children > MAX_CHILDREN
and i == ceil(MAX_CHILDREN / 2)
):
children_html.append("<div>...</div>")

return "".join(
[
"<div style='display: inline-grid; grid-template-columns: 100%; grid-column: 1 / -1'>",
children_html,
"".join(children_html),
"</div>",
]
)
Expand All @@ -378,6 +394,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
name="Groups",
details_func=summarize_datatree_children,
max_items_collapse=1,
max_option_name="display_max_children",
expand_option_name="display_expand_groups",
)

Expand Down
6 changes: 6 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"chunk_manager",
"cmap_divergent",
"cmap_sequential",
"display_max_children",
"display_max_rows",
"display_values_threshold",
"display_style",
Expand Down Expand Up @@ -40,6 +41,7 @@ class T_Options(TypedDict):
chunk_manager: str
cmap_divergent: str | Colormap
cmap_sequential: str | Colormap
display_max_children: int
display_max_rows: int
display_values_threshold: int
display_style: Literal["text", "html"]
Expand Down Expand Up @@ -67,6 +69,7 @@ class T_Options(TypedDict):
"chunk_manager": "dask",
"cmap_divergent": "RdBu_r",
"cmap_sequential": "viridis",
"display_max_children": 6,
"display_max_rows": 12,
"display_values_threshold": 200,
"display_style": "html",
Expand Down Expand Up @@ -99,6 +102,7 @@ def _positive_integer(value: Any) -> bool:
_VALIDATORS = {
"arithmetic_broadcast": lambda value: isinstance(value, bool),
"arithmetic_join": _JOIN_OPTIONS.__contains__,
"display_max_children": _positive_integer,
"display_max_rows": _positive_integer,
"display_values_threshold": _positive_integer,
"display_style": _DISPLAY_OPTIONS.__contains__,
Expand Down Expand Up @@ -222,6 +226,8 @@ class set_options:
* ``True`` : to always expand indexes
* ``False`` : to always collapse indexes
* ``default`` : to expand unless over a pre-defined limit (always collapse for html style)
display_max_children : int, default: 6
Maximum number of children to display for each node in a DataTree.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Maximum number of children to display for each node in a DataTree
Maximum number of children to display for each node in a DataTree.

display_max_rows : int, default: 12
Maximum display rows.
display_values_threshold : int, default: 200
Expand Down
70 changes: 70 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,76 @@ def test_repr_two_children(self) -> None:
).strip()
assert result == expected

def test_repr_truncates_nodes(self) -> None:
# construct a datatree with 50 leaf groups (10 files x 5 groups each)
number_of_files = 10
number_of_groups = 5
tree_dict = {}
for f in range(number_of_files):
for g in range(number_of_groups):
tree_dict[f"file_{f}/group_{g}"] = Dataset({"g": f * g})

tree = DataTree.from_dict(tree_dict)
with xr.set_options(display_max_children=3):
result = repr(tree)

expected = dedent(
"""
<xarray.DataTree>
Group: /
├── Group: /file_0
│ ├── Group: /file_0/group_0
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 0
│ ├── Group: /file_0/group_1
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 0
│ ...
│ └── Group: /file_0/group_4
│ Dimensions: ()
│ Data variables:
│ g int64 8B 0
├── Group: /file_1
│ ├── Group: /file_1/group_0
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 0
│ ├── Group: /file_1/group_1
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 1
│ ...
│ └── Group: /file_1/group_4
│ Dimensions: ()
│ Data variables:
│ g int64 8B 4
...
└── Group: /file_9
├── Group: /file_9/group_0
│ Dimensions: ()
│ Data variables:
│ g int64 8B 0
├── Group: /file_9/group_1
│ Dimensions: ()
│ Data variables:
│ g int64 8B 9
...
└── Group: /file_9/group_4
Dimensions: ()
Data variables:
g int64 8B 36
"""
).strip()
assert expected == result

with xr.set_options(display_max_children=10):
result = repr(tree)

for key in tree_dict:
assert key in result

def test_repr_inherited_dims(self) -> None:
tree = DataTree.from_dict(
{
Expand Down
46 changes: 46 additions & 0 deletions xarray/tests/test_formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,52 @@ def test_two_children(
)


class TestDataTreeTruncatesNodes:
def test_many_nodes(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_many_nodes(self):
def test_many_nodes(self) -> None:

# construct a datatree with 500 leaf groups (20 files x 25 groups each)
number_of_files = 20
number_of_groups = 25
tree_dict = {}
for f in range(number_of_files):
for g in range(number_of_groups):
tree_dict[f"file_{f}/group_{g}"] = xr.Dataset({"g": f * g})

tree = xr.DataTree.from_dict(tree_dict)
with xr.set_options(display_style="html"):
result = tree._repr_html_()

assert "6/20" in result
for i in range(number_of_files):
if i < 3 or i >= (number_of_files - 3):
assert f"file_{i}</div>" in result
else:
assert f"file_{i}</div>" not in result

assert "6/25" in result
for i in range(number_of_groups):
if i < 3 or i >= (number_of_groups - 3):
assert f"group_{i}</div>" in result
else:
assert f"group_{i}</div>" not in result

with xr.set_options(display_style="html", display_max_children=3):
result = tree._repr_html_()

assert "3/20" in result
for i in range(number_of_files):
if i < 2 or i >= (number_of_files - 1):
assert f"file_{i}</div>" in result
else:
assert f"file_{i}</div>" not in result

assert "3/25" in result
for i in range(number_of_groups):
if i < 2 or i >= (number_of_groups - 1):
assert f"group_{i}</div>" in result
else:
assert f"group_{i}</div>" not in result


class TestDataTreeInheritance:
def test_inherited_section_present(self) -> None:
dt = xr.DataTree.from_dict(
Expand Down
Loading