Skip to content

Commit 50fb4a9

Browse files
Add :include-subclasses: option to sphinx.ext.inheritance_diagram (#8159)
When this option is given, all sub-classes of the classes listed as arguments to `inheritance-diagram` are included in the diagram. This makes it possible to generate a complete inheritance tree by just listing the base class and adding `:include-subclasses:`. This is different from specifying the module that contains this base class, as this module might include other classes that should not be part of the diagram. Co-authored-by: Adam Turner <[email protected]>
1 parent 34519be commit 50fb4a9

File tree

3 files changed

+52
-15
lines changed

3 files changed

+52
-15
lines changed

CHANGES.rst

+3
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ Features added
9191
* #12507: Add the :ref:`collapsible <collapsible-admonitions>` option
9292
to admonition directives.
9393
Patch by Chris Sewell.
94+
* #8191, #8159: Add :rst:dir:`inheritance-diagram:include-subclasses` option to
95+
the :rst:dir:`inheritance-diagram` directive.
96+
Patch by Walter Dörwald.
9497

9598
Bugs fixed
9699
----------

doc/usage/extensions/inheritance.rst

+18
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,24 @@ It adds this directive:
100100
.. versionchanged:: 1.7
101101
Added ``top-classes`` option to limit the scope of inheritance graphs.
102102

103+
.. rst:directive:option:: include-subclasses
104+
:type: no value
105+
106+
.. versionadded:: 8.2
107+
108+
If given, any subclass of the classes will be added to the diagram too.
109+
110+
Given the Python module from above, you can specify
111+
your inheritance diagram like this:
112+
113+
.. code-block:: rst
114+
115+
.. inheritance-diagram:: dummy.test.A
116+
:include-subclasses:
117+
118+
This will include the classes A, B, C, D, E and F in the inheritance diagram
119+
but no other classes in the module ``dummy.test``.
120+
103121
104122
Examples
105123
--------

sphinx/ext/inheritance_diagram.py

+31-15
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class E(B): pass
5252
from sphinx.util.docutils import SphinxDirective
5353

5454
if TYPE_CHECKING:
55-
from collections.abc import Iterable, Sequence
55+
from collections.abc import Collection, Iterable, Iterator, Sequence, Set
5656
from typing import Any, ClassVar, Final
5757

5858
from docutils.nodes import Node
@@ -106,7 +106,7 @@ def try_import(objname: str) -> Any:
106106
return None
107107

108108

109-
def import_classes(name: str, currmodule: str) -> Any:
109+
def import_classes(name: str, currmodule: str) -> list[type[Any]]:
110110
"""Import a class using its fully-qualified *name*."""
111111
target = None
112112

@@ -156,37 +156,45 @@ def __init__(
156156
private_bases: bool = False,
157157
parts: int = 0,
158158
aliases: dict[str, str] | None = None,
159-
top_classes: Sequence[Any] = (),
159+
top_classes: Set[str] = frozenset(),
160+
include_subclasses: bool = False,
160161
) -> None:
161162
"""*class_names* is a list of child classes to show bases from.
162163
163164
If *show_builtins* is True, then Python builtins will be shown
164165
in the graph.
165166
"""
166167
self.class_names = class_names
167-
classes = self._import_classes(class_names, currmodule)
168+
classes: Collection[type[Any]] = self._import_classes(class_names, currmodule)
169+
if include_subclasses:
170+
classes_set = {*classes}
171+
for cls in tuple(classes_set):
172+
classes_set.update(_subclasses(cls))
173+
classes = classes_set
168174
self.class_info = self._class_info(
169175
classes, show_builtins, private_bases, parts, aliases, top_classes
170176
)
171177
if not self.class_info:
172178
msg = 'No classes found for inheritance diagram'
173179
raise InheritanceException(msg)
174180

175-
def _import_classes(self, class_names: list[str], currmodule: str) -> list[Any]:
181+
def _import_classes(
182+
self, class_names: list[str], currmodule: str
183+
) -> Sequence[type[Any]]:
176184
"""Import a list of classes."""
177-
classes: list[Any] = []
185+
classes: list[type[Any]] = []
178186
for name in class_names:
179187
classes.extend(import_classes(name, currmodule))
180188
return classes
181189

182190
def _class_info(
183191
self,
184-
classes: list[Any],
192+
classes: Collection[type[Any]],
185193
show_builtins: bool,
186194
private_bases: bool,
187195
parts: int,
188196
aliases: dict[str, str] | None,
189-
top_classes: Sequence[Any],
197+
top_classes: Set[str],
190198
) -> list[tuple[str, str, Sequence[str], str | None]]:
191199
"""Return name and bases for all classes that are ancestors of
192200
*classes*.
@@ -205,7 +213,7 @@ def _class_info(
205213
"""
206214
all_classes = {}
207215

208-
def recurse(cls: Any) -> None:
216+
def recurse(cls: type[Any]) -> None:
209217
if not show_builtins and cls in PY_BUILTINS:
210218
return
211219
if not private_bases and cls.__name__.startswith('_'):
@@ -248,7 +256,7 @@ def recurse(cls: Any) -> None:
248256
]
249257

250258
def class_name(
251-
self, cls: Any, parts: int = 0, aliases: dict[str, str] | None = None
259+
self, cls: type[Any], parts: int = 0, aliases: dict[str, str] | None = None
252260
) -> str:
253261
"""Given a class object, return a fully-qualified name.
254262
@@ -377,6 +385,7 @@ class InheritanceDiagram(SphinxDirective):
377385
'private-bases': directives.flag,
378386
'caption': directives.unchanged,
379387
'top-classes': directives.unchanged_required,
388+
'include-subclasses': directives.flag,
380389
}
381390

382391
def run(self) -> list[Node]:
@@ -387,11 +396,11 @@ def run(self) -> list[Node]:
387396
# Store the original content for use as a hash
388397
node['parts'] = self.options.get('parts', 0)
389398
node['content'] = ', '.join(class_names)
390-
node['top-classes'] = []
391-
for cls in self.options.get('top-classes', '').split(','):
392-
cls = cls.strip()
393-
if cls:
394-
node['top-classes'].append(cls)
399+
node['top-classes'] = frozenset({
400+
cls_stripped
401+
for cls in self.options.get('top-classes', '').split(',')
402+
if (cls_stripped := cls.strip())
403+
})
395404

396405
# Create a graph starting with the list of classes
397406
try:
@@ -402,6 +411,7 @@ def run(self) -> list[Node]:
402411
private_bases='private-bases' in self.options,
403412
aliases=self.config.inheritance_alias,
404413
top_classes=node['top-classes'],
414+
include_subclasses='include-subclasses' in self.options,
405415
)
406416
except InheritanceException as err:
407417
return [node.document.reporter.warning(err, line=self.lineno)]
@@ -428,6 +438,12 @@ def run(self) -> list[Node]:
428438
return [figure]
429439

430440

441+
def _subclasses(cls: type[Any]) -> Iterator[type[Any]]:
442+
yield cls
443+
for sub_cls in cls.__subclasses__():
444+
yield from _subclasses(sub_cls)
445+
446+
431447
def get_graph_hash(node: inheritance_diagram) -> str:
432448
encoded = (node['content'] + str(node['parts'])).encode()
433449
return hashlib.md5(encoded, usedforsecurity=False).hexdigest()[-10:]

0 commit comments

Comments
 (0)