diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py index dfe802d..ea97e4c 100644 --- a/pybind11_stubgen/printer.py +++ b/pybind11_stubgen/printer.py @@ -231,9 +231,53 @@ def print_module(self, module: Module) -> list[str]: for type_var in sorted(module.type_vars, key=lambda t: t.name): result.extend(self.print_type_var(type_var)) + class classWriter: + def __init__(self, result, parent, sortedClasses) -> None: + self.sortedClasses = sortedClasses + self.setClassNames = set([c.name for c in sortedClasses]) + self.writtenClasses = set() + self.delayed = [] + self.result = result + self.parent = parent + + def canWriteClass(self, class_: Class) -> bool: + for base in class_.bases: + baseName = str(base) + if baseName in self.setClassNames and baseName not in self.writtenClasses: + return False + return True + + def writeClass(self, class_: Class) -> None: + self.result.extend(self.parent.print_class(class_)) + self.writtenClasses.add(class_.name) + + def writeFirstDelayedReady(self) -> bool: + for class_ in self.delayed: + if self.canWriteClass(class_): + self.writeClass(class_) + self.delayed.remove(class_) + return True + return False + + def writeDelayedClasses(self) -> None: + while self.writeFirstDelayedReady(): + pass + + def writeAll(self) -> None: + for class_ in self.sortedClasses: + self.writeDelayedClasses() + if self.canWriteClass(class_): + self.writeClass(class_) + else: + self.delayed.append(class_) + self.writeDelayedClasses() + if self.delayed: + print(f"Warning: class ordering failed for the following classes {[c.name for c in self.delayed]}") + for class_ in self.delayed: + self.writeClass(class_) + + classWriter(result, self, sorted(module.classes, key=lambda c: c.name)).writeAll() - for class_ in sorted(module.classes, key=lambda c: c.name): - result.extend(self.print_class(class_)) for func in sorted(module.functions, key=lambda f: f.name): result.extend(self.print_function(func))