Skip to content

Commit 1e96b24

Browse files
ax3lnavinpvdaltairwalterskarndevsizmailov
committed
Fix class ordering to preserve inheritance hierarchy (pybind#231)
Classes in generated .pyi stubs were sorted alphabetically, causing derived classes to appear before their base classes and breaking type checkers. Three changes fix this: - Parser: use module.__dict__.items() instead of inspect.getmembers() to preserve the pybind11 registration order (definition order) - Printer: replace alphabetical sort with a configurable _order_classes() dispatch supporting "definition" (default), "topological" (Kahn's algorithm ensuring bases precede derived classes), and "alphabetical" - CLI: add --sort-by option to select the class ordering strategy The topological sort ignores external bases (from other modules) and breaks ties by input position for deterministic output. Cyclic cross-references between classes (e.g. aliases, method signatures) are not inheritance cycles and are already handled by `from __future__ import annotations` in the generated stubs. Closes pybind#231 Based on the approaches in PR pybind#275 by @juelg and PR pybind#294 by @daltairwalter, informed by review feedback from @skarndev and @sizmailov. Co-Authored-By: juelg <[email protected]> Co-Authored-By: daltairwalter <[email protected]> Co-Authored-By: skarndev <[email protected]> Co-Authored-By: sizmailov <[email protected]> Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
1 parent fcd02aa commit 1e96b24

3 files changed

Lines changed: 86 additions & 5 deletions

File tree

pybind11_stubgen/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class CLIArgs(Namespace):
7777
exit_code: bool
7878
dry_run: bool
7979
stub_extension: str
80+
sort_by: str
8081
module_names: list[str]
8182

8283

@@ -216,6 +217,16 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
216217
"Must be 'pyi' (default) or 'py'",
217218
)
218219

220+
parser.add_argument(
221+
"--sort-by",
222+
type=str,
223+
default="definition",
224+
choices=["definition", "topological"],
225+
help="Order of classes in generated stubs. "
226+
"'definition' (default) preserves the order from the module. "
227+
"'topological' sorts by inheritance hierarchy.",
228+
)
229+
219230
parser.add_argument(
220231
"module_names",
221232
metavar="MODULE_NAMES",
@@ -310,7 +321,10 @@ def main(argv: Sequence[str] | None = None) -> None:
310321
args = arg_parser().parse_args(argv, namespace=CLIArgs())
311322

312323
parser = stub_parser_from_args(args)
313-
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
324+
printer = Printer(
325+
invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is,
326+
sort_by=args.sort_by,
327+
)
314328

315329
run(
316330
parser,

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def handle_module(
8989
self, path: QualifiedName, module: types.ModuleType
9090
) -> Module | None:
9191
result = Module(name=path[-1])
92-
for name, member in inspect.getmembers(module):
92+
for name, member in module.__dict__.items():
9393
obj = self.handle_module_member(
9494
QualifiedName([*path, Identifier(name)]), module, member
9595
)

pybind11_stubgen/printer.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
import logging
45
import sys
56

67
from pybind11_stubgen.structs import (
@@ -25,13 +26,79 @@
2526
)
2627

2728

29+
log = logging.getLogger("pybind11_stubgen")
30+
31+
2832
def indent_lines(lines: list[str], by=4) -> list[str]:
    """Return a new list in which every line is prefixed with ``by`` spaces.

    The input list is not modified.
    """
    padding = " " * by
    indented = []
    for text in lines:
        indented.append(padding + text)
    return indented
3034

3135

36+
def _topological_sort_classes(classes: list[Class]) -> list[Class]:
    """Sort classes so that base classes appear before derived classes.

    Uses Kahn's algorithm. Ties are broken by input position, so the result
    is deterministic and as close to the input order as the inheritance
    constraints allow.

    Bases are matched against the classes in ``classes`` by the last
    component of the base's qualified name. Bases that do not match any
    class in the current scope are treated as external and ignored. A base
    whose last component equals the class's own name is also ignored: a
    class cannot inherit from itself, so such a base must refer to a class
    with the same name from another scope (e.g. ``Foo(other_module.Foo)``).
    Without this, the class would get a self-edge, be reported as a cycle,
    and drag all of its derived classes into the unordered fallback.

    Args:
        classes: Classes of a single scope (module or enclosing class).

    Returns:
        The classes ordered so that every in-scope base precedes its derived
        classes. If a genuine inheritance cycle is found, a warning is
        logged and the classes that could not be ordered are appended in
        their original order. An empty input is returned unchanged.
    """
    import heapq

    if not classes:
        return classes

    name_to_index = {c.name: i for i, c in enumerate(classes)}

    # Build adjacency list: base -> [derived, ...]
    # and in-degree count for each class
    children: dict[str, list[str]] = {c.name: [] for c in classes}
    in_degree: dict[str, int] = {c.name: 0 for c in classes}

    for c in classes:
        own_name = str(c.name)
        for base in c.bases:
            base_name = str(base[-1])
            if base_name == own_name:
                # Same-named base from another scope, not a self-reference.
                continue
            if base_name in name_to_index:
                children[base_name].append(c.name)
                in_degree[c.name] += 1

    # Min-heap of input positions: always emit the earliest-defined class
    # among those whose in-scope bases have all been emitted already.
    ready = [name_to_index[name] for name, deg in in_degree.items() if deg == 0]
    heapq.heapify(ready)

    result = []
    while ready:
        current = classes[heapq.heappop(ready)]
        result.append(current)
        for child in children[current.name]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                heapq.heappush(ready, name_to_index[child])

    if len(result) < len(classes):
        placed = {c.name for c in result}
        remaining = [c for c in classes if c.name not in placed]
        log.warning(
            "Cycle detected in class inheritance involving: %s. "
            "Appending in original order.",
            [c.name for c in remaining],
        )
        result.extend(remaining)

    return result
88+
89+
3290
class Printer:
33-
def __init__(self, invalid_expr_as_ellipses: bool):
91+
def __init__(self, invalid_expr_as_ellipses: bool, sort_by: str = "definition"):
    """Store the printer options.

    Args:
        invalid_expr_as_ellipses: Flag kept on the instance for the
            printing methods.
        sort_by: Class ordering strategy consulted by ``_order_classes``;
            defaults to ``"definition"``.
    """
    self.sort_by = sort_by
    self.invalid_expr_as_ellipses = invalid_expr_as_ellipses
94+
95+
def _order_classes(self, classes: list[Class]) -> list[Class]:
    """Return ``classes`` ordered according to ``self.sort_by``.

    ``"definition"`` returns the list as given, ``"alphabetical"`` returns a
    new list sorted by class name, and any other value is handled as
    ``"topological"`` (bases before derived classes).
    """
    strategy = self.sort_by
    if strategy == "definition":
        return classes
    if strategy == "alphabetical":
        return sorted(classes, key=lambda cls: cls.name)
    # Every remaining value is treated as "topological".
    return _topological_sort_classes(classes)
35102

36103
def print_alias(self, alias: Alias) -> list[str]:
37104
return [f"{alias.name} = {alias.origin}"]
@@ -90,7 +157,7 @@ def print_class_body(self, class_: Class) -> list[str]:
90157
if class_.doc is not None:
91158
result.extend(self.print_docstring(class_.doc))
92159

93-
for sub_class in sorted(class_.classes, key=lambda c: c.name):
160+
for sub_class in self._order_classes(class_.classes):
94161
result.extend(self.print_class(sub_class))
95162

96163
modifier_order: dict[Modifier, int] = {
@@ -232,7 +299,7 @@ def print_module(self, module: Module) -> list[str]:
232299
for type_var in sorted(module.type_vars, key=lambda t: t.name):
233300
result.extend(self.print_type_var(type_var))
234301

235-
for class_ in sorted(module.classes, key=lambda c: c.name):
302+
for class_ in self._order_classes(module.classes):
236303
result.extend(self.print_class(class_))
237304

238305
for func in sorted(module.functions, key=lambda f: f.name):

0 commit comments

Comments
 (0)