Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pybind11-stubgen [-h]
[--print-invalid-expressions-as-is]
[--print-safe-value-reprs REGEX]
[--exit-code]
[--sort-by {definition,topological}]
[--stub-extension EXT]
MODULE_NAME [MODULE_NAMES ...]
```
Expand Down
16 changes: 15 additions & 1 deletion pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class CLIArgs(Namespace):
exit_code: bool
dry_run: bool
stub_extension: str
sort_by: str
module_names: list[str]


Expand Down Expand Up @@ -216,6 +217,16 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
"Must be 'pyi' (default) or 'py'",
)

parser.add_argument(
"--sort-by",
type=str,
default="definition",
choices=["definition", "topological"],
help="Order of classes in generated stubs. "
"'definition' (default) preserves the order from the module. "
"'topological' sorts by inheritance hierarchy.",
)

parser.add_argument(
"module_names",
metavar="MODULE_NAMES",
Expand Down Expand Up @@ -310,7 +321,10 @@ def main(argv: Sequence[str] | None = None) -> None:
args = arg_parser().parse_args(argv, namespace=CLIArgs())

parser = stub_parser_from_args(args)
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
printer = Printer(
invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is,
sort_by=args.sort_by,
)

run(
parser,
Expand Down
6 changes: 2 additions & 4 deletions pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def handle_module(
self, path: QualifiedName, module: types.ModuleType
) -> Module | None:
result = Module(name=path[-1])
for name, member in inspect.getmembers(module):
for name, member in module.__dict__.items():
obj = self.handle_module_member(
QualifiedName([*path, Identifier(name)]), module, member
)
Expand Down Expand Up @@ -647,9 +647,7 @@ def parse_function_docstring(
# This syntax is not supported before Python 3.12.
return []
type_vars: list[str] = list(
filter(
bool, map(str.strip, (type_vars_group or "").split(","))
)
filter(bool, map(str.strip, (type_vars_group or "").split(",")))
)
args = self.call_with_local_types(
type_vars, lambda: self.parse_args_str(match.group("args"))
Expand Down
70 changes: 67 additions & 3 deletions pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import logging
import sys

from pybind11_stubgen.structs import (
Expand All @@ -24,14 +25,77 @@
Value,
)

log = logging.getLogger("pybind11_stubgen")


def indent_lines(lines: list[str], by=4) -> list[str]:
return [" " * by + line for line in lines]


def _topological_sort_classes(classes: list[Class]) -> list[Class]:
"""Sort classes so that base classes appear before derived classes.

Uses Kahn's algorithm. Ties are broken by input position for stability.
External bases (not in the current scope) are ignored.
"""
if not classes:
return classes

name_to_index = {c.name: i for i, c in enumerate(classes)}
name_to_class = {c.name: c for c in classes}

# Build adjacency list: base -> [derived, ...]
# and in-degree count for each class
children: dict[str, list[str]] = {c.name: [] for c in classes}
in_degree: dict[str, int] = {c.name: 0 for c in classes}

for c in classes:
for base in c.bases:
base_name = str(base[-1])
if base_name in name_to_class:
children[base_name].append(c.name)
in_degree[c.name] += 1

# Initialize queue with zero in-degree classes, sorted by input position
queue = sorted(
[name for name, deg in in_degree.items() if deg == 0],
key=lambda n: name_to_index[n],
)

result = []
while queue:
name = queue.pop(0)
result.append(name_to_class[name])
# Sort children by input position for stable ordering
for child in sorted(children[name], key=lambda n: name_to_index[n]):
in_degree[child] -= 1
if in_degree[child] == 0:
queue.append(child)
# Re-sort queue to maintain input-position priority
queue.sort(key=lambda n: name_to_index[n])

if len(result) < len(classes):
remaining = [c for c in classes if c.name not in {r.name for r in result}]
log.warning(
"Cycle detected in class inheritance involving: %s. "
"Appending in original order.",
[c.name for c in remaining],
)
result.extend(remaining)

return result


class Printer:
def __init__(self, invalid_expr_as_ellipses: bool):
def __init__(self, invalid_expr_as_ellipses: bool, sort_by: str = "definition"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest an explicit typehint to be more clear on the options.

Suggested change
def __init__(self, invalid_expr_as_ellipses: bool, sort_by: str = "definition"):
def __init__(self, invalid_expr_as_ellipses: bool, sort_by: Literal["definition", "topological"] = "definition"):

self.invalid_expr_as_ellipses = invalid_expr_as_ellipses
self.sort_by = sort_by

def _order_classes(self, classes: list[Class]) -> list[Class]:
if self.sort_by == "definition":
return classes
else: # "topological"
return _topological_sort_classes(classes)

def print_alias(self, alias: Alias) -> list[str]:
return [f"{alias.name} = {alias.origin}"]
Expand Down Expand Up @@ -90,7 +154,7 @@ def print_class_body(self, class_: Class) -> list[str]:
if class_.doc is not None:
result.extend(self.print_docstring(class_.doc))

for sub_class in sorted(class_.classes, key=lambda c: c.name):
for sub_class in self._order_classes(class_.classes):
result.extend(self.print_class(sub_class))

modifier_order: dict[Modifier, int] = {
Expand Down Expand Up @@ -232,7 +296,7 @@ def print_module(self, module: Module) -> list[str]:
for type_var in sorted(module.type_vars, key=lambda t: t.name):
result.extend(self.print_type_var(type_var))

for class_ in sorted(module.classes, key=lambda c: c.name):
for class_ in self._order_classes(module.classes):
result.extend(self.print_class(class_))

for func in sorted(module.functions, key=lambda f: f.name):
Expand Down
47 changes: 39 additions & 8 deletions tests/demo-lib/include/demo/Inheritance.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
#pragma once
#include <string>

namespace demo{
namespace demo
{
// note: class stubs must not be sorted
// https://github.com/sizmailov/pybind11-stubgen/issues/231

struct Base {
struct Inner{};
std::string name;
};
struct MyBase {
struct Inner{};
std::string name;
};

struct Derived : Base {
int count;
};
struct Derived : MyBase {
int count;
};

// Cross-reference test (the "cyclic" case from issue #231 / PR #275):
// ParIterBase is a base class for ParIter.
// ParticleContainer references ParIter (via an alias).
// ParIter.__init__ takes a ParticleContainer& (annotation back-reference).
// This is NOT cyclic inheritance — just interleaved name usage.

struct ParIterBase {
int level;
};

struct ParticleContainer; // forward declaration

struct ParIter : ParIterBase {
ParticleContainer* container;
ParIter(ParticleContainer& pc, int level);
};

struct ParticleContainer {
std::string name;
void process(ParIter& it);
};

inline ParIter::ParIter(ParticleContainer& pc, int level)
: container(&pc), ParIterBase{level} {}

inline void ParticleContainer::process(ParIter& it) {
it.level += 1;
}
}
32 changes: 28 additions & 4 deletions tests/py-demo/bindings/src/modules/classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ void bind_classes_module(py::module&&m) {
}

{
py::class_<demo::Base> pyBase(m, "Base");
py::class_<demo::MyBase> pyMyBase(m, "MyBase");

pyBase.def_readwrite("name", &demo::Base::name);
pyMyBase.def_readwrite("name", &demo::MyBase::name);

py::class_<demo::Base::Inner>(pyBase, "Inner");
py::class_<demo::MyBase::Inner>(pyMyBase, "Inner");

py::class_<demo::Derived, demo::Base>(m, "Derived")
py::class_<demo::Derived, demo::MyBase>(m, "Derived")
.def_readwrite("count", &demo::Derived::count);

}
Expand All @@ -38,6 +38,30 @@ void bind_classes_module(py::module&&m) {
.def("g", &demo::Foo::Child::g);
}

// Cross-reference / "cyclic" test case (issue #231, PR #275):
// Registration order: ParticleContainer, then ParIter, then ParIterBase.
// ParticleContainer.Iterator is an alias to ParIter (cross-ref).
// ParIter inherits ParIterBase and takes ParticleContainer in __init__.
// The topological sort must put ParIterBase before ParIter;
// from __future__ import annotations handles the annotation back-refs.
{
auto pyParIterBase = py::class_<demo::ParIterBase>(m, "ParIterBase");
pyParIterBase.def_readwrite("level", &demo::ParIterBase::level);

auto pyParticleContainer = py::class_<demo::ParticleContainer>(m, "ParticleContainer");
pyParticleContainer.def_readwrite("name", &demo::ParticleContainer::name);

auto pyParIter = py::class_<demo::ParIter, demo::ParIterBase>(m, "ParIter");
pyParIter.def(py::init<demo::ParticleContainer&, int>(),
py::arg("particle_container"), py::arg("level"));

// Bind after ParIter is registered so pybind11 resolves the Python type
pyParticleContainer.def("process", &demo::ParticleContainer::process);

// Alias: ParticleContainer.Iterator = ParIter
pyParticleContainer.attr("Iterator") = pyParIter;
}

{
py::register_exception<demo::CppException>(m, "CppException");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ __all__: list[str] = [
"random",
]

class Color:
pass

class Dummy:
linalg = numpy.linalg

class Color:
pass

def foreign_enum_default(
color: typing.Any = demo._bindings.enum.ConsoleForegroundColor.Blue,
) -> None: ...
def func(arg0: int) -> int: ...

local_func_alias = func
local_type_alias = Color
local_func_alias = func
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,16 @@ from __future__ import annotations

import typing

__all__: list[str] = ["Base", "CppException", "Derived", "Foo", "Outer"]

class Base:
class Inner:
pass

name: str

class CppException(Exception):
pass

class Derived(Base):
count: int

class Foo:
class FooChild:
def __init__(self) -> None: ...
def g(self) -> None: ...

def __init__(self) -> None: ...
def f(self) -> None: ...
__all__: list[str] = [
"CppException",
"Derived",
"Foo",
"MyBase",
"Outer",
"ParIter",
"ParIterBase",
"ParticleContainer",
]

class Outer:
class Inner:
Expand Down Expand Up @@ -58,3 +47,34 @@ class Outer:
value: Outer.Inner.NestedEnum

inner: Outer.Inner

class MyBase:
class Inner:
pass

name: str

class Derived(MyBase):
count: int

class Foo:
class FooChild:
def __init__(self) -> None: ...
def g(self) -> None: ...

def __init__(self) -> None: ...
def f(self) -> None: ...

class ParIterBase:
level: int

class ParticleContainer:
name: str
Iterator = ParIter
def process(self, arg0: ParIter) -> None: ...

class ParIter(ParIterBase):
def __init__(self, particle_container: ParticleContainer, level: int) -> None: ...

class CppException(Exception):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ __all__: list[str] = [
"get_unbound_type",
]

class Enum:
class Unbound:
pass

class Unbound:
class Enum:
pass

def accept_unbound_enum(arg0: ...) -> int: ...
Expand Down
Loading
Loading