@@ -52,7 +52,7 @@ class E(B): pass
52
52
from sphinx .util .docutils import SphinxDirective
53
53
54
54
if TYPE_CHECKING :
55
- from collections .abc import Iterable , Sequence
55
+ from collections .abc import Collection , Iterable , Iterator , Sequence , Set
56
56
from typing import Any , ClassVar , Final
57
57
58
58
from docutils .nodes import Node
@@ -106,7 +106,7 @@ def try_import(objname: str) -> Any:
106
106
return None
107
107
108
108
109
- def import_classes (name : str , currmodule : str ) -> Any :
109
+ def import_classes (name : str , currmodule : str ) -> list [ type [ Any ]] :
110
110
"""Import a class using its fully-qualified *name*."""
111
111
target = None
112
112
@@ -156,37 +156,45 @@ def __init__(
156
156
private_bases : bool = False ,
157
157
parts : int = 0 ,
158
158
aliases : dict [str , str ] | None = None ,
159
- top_classes : Sequence [Any ] = (),
159
+ top_classes : Set [str ] = frozenset (),
160
+ include_subclasses : bool = False ,
160
161
) -> None :
161
162
"""*class_names* is a list of child classes to show bases from.
162
163
163
164
If *show_builtins* is True, then Python builtins will be shown
164
165
in the graph.
165
166
"""
166
167
self .class_names = class_names
167
- classes = self ._import_classes (class_names , currmodule )
168
+ classes : Collection [type [Any ]] = self ._import_classes (class_names , currmodule )
169
+ if include_subclasses :
170
+ classes_set = {* classes }
171
+ for cls in tuple (classes_set ):
172
+ classes_set .update (_subclasses (cls ))
173
+ classes = classes_set
168
174
self .class_info = self ._class_info (
169
175
classes , show_builtins , private_bases , parts , aliases , top_classes
170
176
)
171
177
if not self .class_info :
172
178
msg = 'No classes found for inheritance diagram'
173
179
raise InheritanceException (msg )
174
180
175
- def _import_classes (self , class_names : list [str ], currmodule : str ) -> list [Any ]:
181
+ def _import_classes (
182
+ self , class_names : list [str ], currmodule : str
183
+ ) -> Sequence [type [Any ]]:
176
184
"""Import a list of classes."""
177
- classes : list [Any ] = []
185
+ classes : list [type [ Any ] ] = []
178
186
for name in class_names :
179
187
classes .extend (import_classes (name , currmodule ))
180
188
return classes
181
189
182
190
def _class_info (
183
191
self ,
184
- classes : list [ Any ],
192
+ classes : Collection [ type [ Any ] ],
185
193
show_builtins : bool ,
186
194
private_bases : bool ,
187
195
parts : int ,
188
196
aliases : dict [str , str ] | None ,
189
- top_classes : Sequence [ Any ],
197
+ top_classes : Set [ str ],
190
198
) -> list [tuple [str , str , Sequence [str ], str | None ]]:
191
199
"""Return name and bases for all classes that are ancestors of
192
200
*classes*.
@@ -205,7 +213,7 @@ def _class_info(
205
213
"""
206
214
all_classes = {}
207
215
208
- def recurse (cls : Any ) -> None :
216
+ def recurse (cls : type [ Any ] ) -> None :
209
217
if not show_builtins and cls in PY_BUILTINS :
210
218
return
211
219
if not private_bases and cls .__name__ .startswith ('_' ):
@@ -248,7 +256,7 @@ def recurse(cls: Any) -> None:
248
256
]
249
257
250
258
def class_name (
251
- self , cls : Any , parts : int = 0 , aliases : dict [str , str ] | None = None
259
+ self , cls : type [ Any ] , parts : int = 0 , aliases : dict [str , str ] | None = None
252
260
) -> str :
253
261
"""Given a class object, return a fully-qualified name.
254
262
@@ -377,6 +385,7 @@ class InheritanceDiagram(SphinxDirective):
377
385
'private-bases' : directives .flag ,
378
386
'caption' : directives .unchanged ,
379
387
'top-classes' : directives .unchanged_required ,
388
+ 'include-subclasses' : directives .flag ,
380
389
}
381
390
382
391
def run (self ) -> list [Node ]:
@@ -387,11 +396,11 @@ def run(self) -> list[Node]:
387
396
# Store the original content for use as a hash
388
397
node ['parts' ] = self .options .get ('parts' , 0 )
389
398
node ['content' ] = ', ' .join (class_names )
390
- node ['top-classes' ] = []
391
- for cls in self . options . get ( 'top-classes' , '' ). split ( ',' ):
392
- cls = cls . strip ( )
393
- if cls :
394
- node [ 'top-classes' ]. append ( cls )
399
+ node ['top-classes' ] = frozenset ({
400
+ cls_stripped
401
+ for cls in self . options . get ( 'top-classes' , '' ). split ( ',' )
402
+ if ( cls_stripped := cls . strip ())
403
+ } )
395
404
396
405
# Create a graph starting with the list of classes
397
406
try :
@@ -402,6 +411,7 @@ def run(self) -> list[Node]:
402
411
private_bases = 'private-bases' in self .options ,
403
412
aliases = self .config .inheritance_alias ,
404
413
top_classes = node ['top-classes' ],
414
+ include_subclasses = 'include-subclasses' in self .options ,
405
415
)
406
416
except InheritanceException as err :
407
417
return [node .document .reporter .warning (err , line = self .lineno )]
@@ -428,6 +438,12 @@ def run(self) -> list[Node]:
428
438
return [figure ]
429
439
430
440
441
+ def _subclasses (cls : type [Any ]) -> Iterator [type [Any ]]:
442
+ yield cls
443
+ for sub_cls in cls .__subclasses__ ():
444
+ yield from _subclasses (sub_cls )
445
+
446
+
431
447
def get_graph_hash (node : inheritance_diagram ) -> str :
432
448
encoded = (node ['content' ] + str (node ['parts' ])).encode ()
433
449
return hashlib .md5 (encoded , usedforsecurity = False ).hexdigest ()[- 10 :]
0 commit comments