diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index ee778f6bfd9..6106400ebdd 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -12,7 +12,7 @@ jobs: benchmark: if: ${{ contains( github.event.pull_request.labels.*.name, 'run-benchmark') && github.event_name == 'pull_request' || contains( github.event.pull_request.labels.*.name, 'topic-performance') && github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' }} name: Linux - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 env: ASV_DIR: "./asv_bench" CONDA_ENV_FILE: ci/requirements/environment.yml diff --git a/xarray/core/datatree_render.py b/xarray/core/datatree_render.py index 11336cd9689..f1042d9eeef 100644 --- a/xarray/core/datatree_render.py +++ b/xarray/core/datatree_render.py @@ -10,6 +10,7 @@ from collections import namedtuple from collections.abc import Iterable, Iterator +from math import ceil from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -79,6 +80,7 @@ def __init__( style=None, childiter: type = list, maxlevel: int | None = None, + maxchildren: int | None = None, ): """ Render tree starting at `node`. @@ -88,6 +90,7 @@ def __init__( Iterables that change the order of children cannot be used (e.g., `reversed`). maxlevel: Limit rendering to this depth. + maxchildren: Limit number of children at each node. :any:`RenderDataTree` is an iterator, returning a tuple with 3 items: `pre` tree prefix. @@ -160,6 +163,16 @@ def __init__( root ├── sub0 └── sub1 + + # `maxchildren` limits the number of children per node + + >>> print(RenderDataTree(root, maxchildren=1).by_attr("name")) + root + ├── sub0 + │ ├── sub0B + │ ... + ... 
+ """ if style is None: style = ContStyle() @@ -169,24 +182,44 @@ def __init__( self.style = style self.childiter = childiter self.maxlevel = maxlevel + self.maxchildren = maxchildren def __iter__(self) -> Iterator[Row]: return self.__next(self.node, tuple()) def __next( - self, node: DataTree, continues: tuple[bool, ...], level: int = 0 + self, + node: DataTree, + continues: tuple[bool, ...], + level: int = 0, ) -> Iterator[Row]: yield RenderDataTree.__item(node, continues, self.style) children = node.children.values() level += 1 if children and (self.maxlevel is None or level < self.maxlevel): + nchildren = len(children) children = self.childiter(children) - for child, is_last in _is_last(children): - yield from self.__next(child, continues + (not is_last,), level=level) + for i, (child, is_last) in enumerate(_is_last(children)): + if ( + self.maxchildren is None + or i < ceil(self.maxchildren / 2) + or i >= ceil(nchildren - self.maxchildren / 2) + ): + yield from self.__next( + child, + continues + (not is_last,), + level=level, + ) + if ( + self.maxchildren is not None + and nchildren > self.maxchildren + and i == ceil(self.maxchildren / 2) + ): + yield RenderDataTree.__item("...", continues, self.style) @staticmethod def __item( - node: DataTree, continues: tuple[bool, ...], style: AbstractStyle + node: DataTree | str, continues: tuple[bool, ...], style: AbstractStyle ) -> Row: if not continues: return Row("", "", node) @@ -244,6 +277,9 @@ def by_attr(self, attrname: str = "name") -> str: def get() -> Iterator[str]: for pre, fill, node in self: + if isinstance(node, str): + yield f"{fill}{node}" + continue attr = ( attrname(node) if callable(attrname) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index e10bc14292c..f713f535647 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -1139,14 +1139,21 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: def datatree_repr(dt: DataTree) -> str: """A 
printable representation of the structure of this entire tree.""" - renderer = RenderDataTree(dt) + max_children = OPTIONS["display_max_children"] + + renderer = RenderDataTree(dt, maxchildren=max_children) name_info = "" if dt.name is None else f" {dt.name!r}" header = f"<xarray.DataTree{name_info}>" lines = [header] show_inherited = True + for pre, fill, node in renderer: + if isinstance(node, str): + lines.append(f"{fill}{node}") + continue + node_repr = _datatree_node_repr(node, show_inherited=show_inherited) show_inherited = False # only show inherited coords on the root diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index eb9073cd869..c0601e3326a 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -6,7 +6,8 @@ from functools import lru_cache, partial from html import escape from importlib.resources import files -from typing import TYPE_CHECKING +from math import ceil +from typing import TYPE_CHECKING, Literal from xarray.core.formatting import ( inherited_vars, @@ -14,7 +15,7 @@ inline_variable_array_repr, short_data_repr, ) -from xarray.core.options import _get_boolean_with_default +from xarray.core.options import OPTIONS, _get_boolean_with_default STATIC_FILES = ( ("xarray.static.html", "icons-svg-inline.html"), @@ -192,7 +193,13 @@ def collapsible_section( def _mapping_section( - mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True + mapping, + name, + details_func, + max_items_collapse, + expand_option_name, + enabled=True, + max_option_name: Literal["display_max_children"] | None = None, ) -> str: n_items = len(mapping) expanded = _get_boolean_with_default( @@ -200,8 +207,15 @@ def _mapping_section( ) collapsed = not expanded + inline_details = "" + if max_option_name and max_option_name in OPTIONS: + max_items = int(OPTIONS[max_option_name]) + if n_items > max_items: + inline_details = f"({max_items}/{n_items})" + return collapsible_section( name, + inline_details=inline_details, 
details=details_func(mapping), n_items=n_items, enabled=enabled, @@ -348,26 +362,23 @@ def dataset_repr(ds) -> str: def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: - N_CHILDREN = len(children) - 1 - - # Get result from datatree_node_repr and wrap it - lines_callback = lambda n, c, end: _wrap_datatree_repr( - datatree_node_repr(n, c), end=end - ) - - children_html = "".join( - ( - lines_callback(n, c, end=False) # Long lines - if i < N_CHILDREN - else lines_callback(n, c, end=True) - ) # Short lines - for i, (n, c) in enumerate(children.items()) - ) + MAX_CHILDREN = OPTIONS["display_max_children"] + n_children = len(children) + + children_html = [] + for i, (n, c) in enumerate(children.items()): + if i < ceil(MAX_CHILDREN / 2) or i >= ceil(n_children - MAX_CHILDREN / 2): + is_last = i == (n_children - 1) + children_html.append( + _wrap_datatree_repr(datatree_node_repr(n, c), end=is_last) + ) + elif n_children > MAX_CHILDREN and i == ceil(MAX_CHILDREN / 2): + children_html.append("
...
") return "".join( [ "
", - children_html, + "".join(children_html), "
", ] ) @@ -378,6 +389,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: name="Groups", details_func=summarize_datatree_children, max_items_collapse=1, + max_option_name="display_max_children", expand_option_name="display_expand_groups", ) diff --git a/xarray/core/options.py b/xarray/core/options.py index 2d69e4b6584..adaa563d09b 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -13,6 +13,7 @@ "chunk_manager", "cmap_divergent", "cmap_sequential", + "display_max_children", "display_max_rows", "display_values_threshold", "display_style", @@ -40,6 +41,7 @@ class T_Options(TypedDict): chunk_manager: str cmap_divergent: str | Colormap cmap_sequential: str | Colormap + display_max_children: int display_max_rows: int display_values_threshold: int display_style: Literal["text", "html"] @@ -67,6 +69,7 @@ class T_Options(TypedDict): "chunk_manager": "dask", "cmap_divergent": "RdBu_r", "cmap_sequential": "viridis", + "display_max_children": 6, "display_max_rows": 12, "display_values_threshold": 200, "display_style": "html", @@ -99,6 +102,7 @@ def _positive_integer(value: Any) -> bool: _VALIDATORS = { "arithmetic_broadcast": lambda value: isinstance(value, bool), "arithmetic_join": _JOIN_OPTIONS.__contains__, + "display_max_children": _positive_integer, "display_max_rows": _positive_integer, "display_values_threshold": _positive_integer, "display_style": _DISPLAY_OPTIONS.__contains__, @@ -222,6 +226,8 @@ class set_options: * ``True`` : to always expand indexes * ``False`` : to always collapse indexes * ``default`` : to expand unless over a pre-defined limit (always collapse for html style) + display_max_children : int, default: 6 + Maximum number of children to display for each node in a DataTree. display_max_rows : int, default: 12 Maximum display rows. 
display_values_threshold : int, default: 200 diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index ac222636cbf..23bc194695c 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1219,6 +1219,76 @@ def test_repr_two_children(self) -> None: ).strip() assert result == expected + def test_repr_truncates_nodes(self) -> None: + # construct a datatree with 50 nodes + number_of_files = 10 + number_of_groups = 5 + tree_dict = {} + for f in range(number_of_files): + for g in range(number_of_groups): + tree_dict[f"file_{f}/group_{g}"] = Dataset({"g": f * g}) + + tree = DataTree.from_dict(tree_dict) + with xr.set_options(display_max_children=3): + result = repr(tree) + + expected = dedent( + """ + <xarray.DataTree> + Group: / + ├── Group: /file_0 + │ ├── Group: /file_0/group_0 + │ │ Dimensions: () + │ │ Data variables: + │ │ g int64 8B 0 + │ ├── Group: /file_0/group_1 + │ │ Dimensions: () + │ │ Data variables: + │ │ g int64 8B 0 + │ ... + │ └── Group: /file_0/group_4 + │ Dimensions: () + │ Data variables: + │ g int64 8B 0 + ├── Group: /file_1 + │ ├── Group: /file_1/group_0 + │ │ Dimensions: () + │ │ Data variables: + │ │ g int64 8B 0 + │ ├── Group: /file_1/group_1 + │ │ Dimensions: () + │ │ Data variables: + │ │ g int64 8B 1 + │ ... + │ └── Group: /file_1/group_4 + │ Dimensions: () + │ Data variables: + │ g int64 8B 4 + ... + └── Group: /file_9 + ├── Group: /file_9/group_0 + │ Dimensions: () + │ Data variables: + │ g int64 8B 0 + ├── Group: /file_9/group_1 + │ Dimensions: () + │ Data variables: + │ g int64 8B 9 + ... 
+ └── Group: /file_9/group_4 + Dimensions: () + Data variables: + g int64 8B 36 + """ + ).strip() + assert expected == result + + with xr.set_options(display_max_children=10): + result = repr(tree) + + for key in tree_dict: + assert key in result + def test_repr_inherited_dims(self) -> None: tree = DataTree.from_dict( { diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 7c9cdbeaaf5..a17fffc3683 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -320,6 +320,52 @@ def test_two_children( ) +class TestDataTreeTruncatesNodes: + def test_many_nodes(self) -> None: + # construct a datatree with 500 nodes + number_of_files = 20 + number_of_groups = 25 + tree_dict = {} + for f in range(number_of_files): + for g in range(number_of_groups): + tree_dict[f"file_{f}/group_{g}"] = xr.Dataset({"g": f * g}) + + tree = xr.DataTree.from_dict(tree_dict) + with xr.set_options(display_style="html"): + result = tree._repr_html_() + + assert "6/20" in result + for i in range(number_of_files): + if i < 3 or i >= (number_of_files - 3): + assert f"file_{i}" in result + else: + assert f"file_{i}" not in result + + assert "6/25" in result + for i in range(number_of_groups): + if i < 3 or i >= (number_of_groups - 3): + assert f"group_{i}" in result + else: + assert f"group_{i}" not in result + + with xr.set_options(display_style="html", display_max_children=3): + result = tree._repr_html_() + + assert "3/20" in result + for i in range(number_of_files): + if i < 2 or i >= (number_of_files - 1): + assert f"file_{i}" in result + else: + assert f"file_{i}" not in result + + assert "3/25" in result + for i in range(number_of_groups): + if i < 2 or i >= (number_of_groups - 1): + assert f"group_{i}" in result + else: + assert f"group_{i}" not in result + + class TestDataTreeInheritance: def test_inherited_section_present(self) -> None: dt = xr.DataTree.from_dict(