Skip to content

Commit 7d73e90

Browse files
committed
[feat] introduce extra_bases configuration option to support third-party Pydantic base models
1 parent 3c432f5 commit 7d73e90

File tree

5 files changed

+187
-20
lines changed

5 files changed

+187
-20
lines changed

README.md

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ import griffe
3131
griffe.load(
3232
"mypackage",
3333
extensions=griffe.load_extensions(
34-
[{"griffe_pydantic": {"schema": True}}]
34+
[{"griffe_pydantic": {
35+
"schema": True,
36+
"extra_bases": ["sqlmodel.SQLModel", "my_package.CustomBaseModel"]
37+
}}]
3538
)
3639
)
3740
```
@@ -49,7 +52,32 @@ plugins:
4952
extensions:
5053
- griffe_pydantic:
5154
schema: true
55+
extra_bases:
56+
- sqlmodel.SQLModel
57+
- my_package.CustomBaseModel
5258
```
5359
5460
5561
See [MkDocs usage in Griffe's documentation](https://mkdocstrings.github.io/griffe/extensions/#in-mkdocs).
62+
63+
## Configuration Options
64+
65+
### `extra_bases`
66+
67+
By default, the extension detects classes that directly inherit from `pydantic.BaseModel`. You can configure additional base classes to be detected using the `extra_bases` option:
68+
69+
```yaml
70+
extensions:
71+
- griffe_pydantic:
72+
extra_bases:
73+
- sqlmodel.SQLModel
74+
- my_package.CustomBaseModel
75+
```
76+
77+
This is useful when working with libraries that provide their own base classes which inherit from `pydantic.BaseModel`, such as:
78+
79+
- `SQLModel` from the SQLModel library
80+
- Custom base classes in your own codebase
81+
- Third-party libraries that extend Pydantic models
82+
83+
The extension will detect any class that inherits from these base classes, as long as the base class ultimately inherits from `pydantic.BaseModel`. This enables proper documentation generation for models that use custom or third-party Pydantic-based base classes.

src/griffe_pydantic/_internal/extension.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,49 @@
2222
class PydanticExtension(Extension):
2323
"""Griffe extension for Pydantic."""
2424

25-
def __init__(self, *, schema: bool = False) -> None:
25+
def __init__(
26+
self,
27+
*,
28+
schema: bool = False,
29+
extra_bases: list[str] | None = None,
30+
) -> None:
2631
"""Initialize the extension.
2732
2833
Parameters:
2934
schema: Whether to compute and store the JSON schema of models.
35+
extra_bases: Additional base classes to detect as Pydantic models.
3036
"""
3137
super().__init__()
3238
self._schema = schema
39+
self._extra_bases = extra_bases or []
3340
self._processed: set[str] = set()
3441
self._recorded: list[tuple[ObjectNode, Class]] = []
3542

3643
def on_package(self, *, pkg: Module, **kwargs: Any) -> None: # noqa: ARG002
3744
"""Detect models once the whole package is loaded."""
3845
for node, cls in self._recorded:
3946
self._processed.add(cls.canonical_path)
40-
dynamic._process_class(node.obj, cls, processed=self._processed, schema=self._schema)
41-
static._process_module(pkg, processed=self._processed, schema=self._schema)
47+
dynamic._process_class(
48+
node.obj,
49+
cls,
50+
processed=self._processed,
51+
schema=self._schema,
52+
)
53+
self._recorded.clear()
54+
static._process_module(
55+
pkg,
56+
processed=self._processed,
57+
schema=self._schema,
58+
extra_bases=self._extra_bases,
59+
)
4260

43-
def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None: # noqa: ARG002
61+
def on_class_instance(
62+
self,
63+
*,
64+
node: ast.AST | ObjectNode,
65+
cls: Class,
66+
**kwargs: Any, # noqa: ARG002
67+
) -> None:
4468
"""Detect and prepare Pydantic models."""
4569
# Prevent running during static analysis.
4670
if isinstance(node, ast.AST):
@@ -52,5 +76,46 @@ def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs:
5276
_logger.warning("could not import pydantic - models will not be detected")
5377
return
5478

79+
# Check if it's a standard Pydantic model
5580
if issubclass(node.obj, pydantic.BaseModel):
5681
self._recorded.append((node, cls))
82+
return
83+
84+
# Check if it's a subclass of any extra base classes
85+
for extra_base in self._extra_bases:
86+
try:
87+
# Import the extra base class
88+
parts = extra_base.split(".")
89+
module_name = ".".join(parts[:-1])
90+
class_name = parts[-1]
91+
92+
if module_name:
93+
import importlib # noqa: PLC0415
94+
95+
module = importlib.import_module(module_name)
96+
base_class = getattr(module, class_name)
97+
else:
98+
# Handle case where only class name is provided (in current module)
99+
base_class = globals().get(class_name)
100+
if base_class is None:
101+
continue
102+
103+
if base_class and issubclass(node.obj, base_class):
104+
try:
105+
# Verify that this extra base ultimately inherits from BaseModel
106+
if issubclass(base_class, pydantic.BaseModel):
107+
self._recorded.append((node, cls))
108+
return
109+
_logger.debug(f"Extra base class {extra_base} does not inherit from pydantic.BaseModel")
110+
except TypeError:
111+
# issubclass() can raise TypeError if base_class is not a class
112+
_logger.debug(f"Extra base {extra_base} is not a valid class for issubclass check")
113+
continue
114+
except (ImportError, AttributeError):
115+
# Skip if we can't import or resolve the extra base
116+
_logger.debug(f"Could not resolve extra base class: {extra_base}")
117+
continue
118+
except TypeError:
119+
# issubclass() with node.obj failed
120+
_logger.debug(f"Could not check inheritance from extra base class: {extra_base}")
121+
continue

src/griffe_pydantic/_internal/static.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,28 @@
2929
_logger = get_logger("griffe_pydantic")
3030

3131

32-
def _inherits_pydantic(cls: Class) -> bool:
32+
def _inherits_pydantic(cls: Class, extra_bases: list[str] | None = None) -> bool:
3333
"""Tell whether a class inherits from a Pydantic model.
3434
3535
Parameters:
3636
cls: A Griffe class.
37+
extra_bases: Additional base classes to consider as Pydantic models.
3738
3839
Returns:
3940
True/False.
4041
"""
42+
extra_bases = extra_bases or []
43+
4144
for base in cls.bases:
4245
if isinstance(base, (ExprName, Expr)):
4346
base = base.canonical_path # noqa: PLW2901
4447
if base in {"pydantic.BaseModel", "pydantic.main.BaseModel"}:
4548
return True
49+
# Check if the base matches any of the extra bases
50+
if base in extra_bases:
51+
return True
4652

47-
return any(_inherits_pydantic(parent_class) for parent_class in cls.mro())
53+
return any(_inherits_pydantic(parent_class, extra_bases) for parent_class in cls.mro())
4854

4955

5056
def _pydantic_validator(func: Function) -> ExprCall | None:
@@ -123,7 +129,9 @@ def _process_attribute(attr: Attribute, cls: Class, *, processed: set[str]) -> N
123129
try:
124130
attr.docstring = Docstring(ast.literal_eval(docstring), parent=attr) # type: ignore[arg-type]
125131
except ValueError:
126-
_logger.debug(f"Could not parse description of field '{attr.path}' as literal, skipping")
132+
_logger.debug(
133+
f"Could not parse description of field '{attr.path}' as literal, skipping",
134+
)
127135

128136

129137
def _process_function(func: Function, cls: Class, *, processed: set[str]) -> None:
@@ -141,12 +149,18 @@ def _process_function(func: Function, cls: Class, *, processed: set[str]) -> Non
141149
common._process_function(func, cls, fields)
142150

143151

144-
def _process_class(cls: Class, *, processed: set[str], schema: bool = False) -> None:
152+
def _process_class(
153+
cls: Class,
154+
*,
155+
processed: set[str],
156+
schema: bool = False,
157+
extra_bases: list[str] | None = None,
158+
) -> None:
145159
"""Finalize the Pydantic model data."""
146160
if cls.canonical_path in processed:
147161
return
148162

149-
if not _inherits_pydantic(cls):
163+
if not _inherits_pydantic(cls, extra_bases):
150164
return
151165

152166
processed.add(cls.canonical_path)
@@ -174,14 +188,20 @@ def _process_class(cls: Class, *, processed: set[str], schema: bool = False) ->
174188
elif kind is Kind.FUNCTION:
175189
_process_function(member, cls, processed=processed) # type: ignore[arg-type]
176190
elif kind is Kind.CLASS:
177-
_process_class(member, processed=processed, schema=schema) # type: ignore[arg-type]
191+
_process_class(
192+
member, # type: ignore[arg-type]
193+
processed=processed,
194+
schema=schema,
195+
extra_bases=extra_bases,
196+
)
178197

179198

180199
def _process_module(
181200
mod: Module,
182201
*,
183202
processed: set[str],
184203
schema: bool = False,
204+
extra_bases: list[str] | None = None,
185205
) -> None:
186206
"""Handle Pydantic models in a module."""
187207
if mod.canonical_path in processed:
@@ -191,7 +211,17 @@ def _process_module(
191211
for cls in mod.classes.values():
192212
# Don't process aliases, real classes will be processed at some point anyway.
193213
if not cls.is_alias:
194-
_process_class(cls, processed=processed, schema=schema)
214+
_process_class(
215+
cls,
216+
processed=processed,
217+
schema=schema,
218+
extra_bases=extra_bases,
219+
)
195220

196221
for submodule in mod.modules.values():
197-
_process_module(submodule, processed=processed, schema=schema)
222+
_process_module(
223+
submodule,
224+
processed=processed,
225+
schema=schema,
226+
extra_bases=extra_bases,
227+
)

tests/test_api.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,31 @@ def _yield_public_objects(
7373

7474

7575
@pytest.fixture(name="modulelevel_internal_objects", scope="module")
76-
def _fixture_modulelevel_internal_objects(internal_api: griffe.Module) -> list[griffe.Object | griffe.Alias]:
76+
def _fixture_modulelevel_internal_objects(
77+
internal_api: griffe.Module,
78+
) -> list[griffe.Object | griffe.Alias]:
7779
return list(_yield_public_objects(internal_api, modulelevel=True))
7880

7981

8082
@pytest.fixture(name="internal_objects", scope="module")
81-
def _fixture_internal_objects(internal_api: griffe.Module) -> list[griffe.Object | griffe.Alias]:
83+
def _fixture_internal_objects(
84+
internal_api: griffe.Module,
85+
) -> list[griffe.Object | griffe.Alias]:
8286
return list(_yield_public_objects(internal_api, modulelevel=False, special=True))
8387

8488

8589
@pytest.fixture(name="public_objects", scope="module")
86-
def _fixture_public_objects(public_api: griffe.Module) -> list[griffe.Object | griffe.Alias]:
87-
return list(_yield_public_objects(public_api, modulelevel=False, inherited=True, special=True))
90+
def _fixture_public_objects(
91+
public_api: griffe.Module,
92+
) -> list[griffe.Object | griffe.Alias]:
93+
return list(
94+
_yield_public_objects(
95+
public_api,
96+
modulelevel=False,
97+
inherited=True,
98+
special=True,
99+
),
100+
)
88101

89102

90103
@pytest.fixture(name="inventory", scope="module")
@@ -96,7 +109,9 @@ def _fixture_inventory() -> Inventory:
96109
return Inventory.parse_sphinx(file)
97110

98111

99-
def test_exposed_objects(modulelevel_internal_objects: list[griffe.Object | griffe.Alias]) -> None:
112+
def test_exposed_objects(
113+
modulelevel_internal_objects: list[griffe.Object | griffe.Alias],
114+
) -> None:
100115
"""All public objects in the internal API are exposed under `griffe_pydantic`."""
101116
not_exposed = [
102117
obj.path
@@ -106,7 +121,9 @@ def test_exposed_objects(modulelevel_internal_objects: list[griffe.Object | grif
106121
assert not not_exposed, "Objects not exposed:\n" + "\n".join(sorted(not_exposed))
107122

108123

109-
def test_unique_names(modulelevel_internal_objects: list[griffe.Object | griffe.Alias]) -> None:
124+
def test_unique_names(
125+
modulelevel_internal_objects: list[griffe.Object | griffe.Alias],
126+
) -> None:
110127
"""All internal objects have unique names."""
111128
names_to_paths = defaultdict(list)
112129
for obj in modulelevel_internal_objects:
@@ -133,7 +150,10 @@ def _public_path(obj: griffe.Object | griffe.Alias) -> bool:
133150
)
134151

135152

136-
def test_api_matches_inventory(inventory: Inventory, public_objects: list[griffe.Object | griffe.Alias]) -> None:
153+
def test_api_matches_inventory(
154+
inventory: Inventory,
155+
public_objects: list[griffe.Object | griffe.Alias],
156+
) -> None:
137157
"""All public objects are added to the inventory."""
138158
ignore_names = {"__getattr__", "__init__", "__repr__", "__str__", "__post_init__"}
139159
not_in_inventory = [

tests/test_extension.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,30 @@ class Model(Base):
216216
assert "pydantic-field" not in package["Model.a"].labels
217217

218218

219+
def test_extra_bases() -> None:
220+
"""Test extra_bases functionality with custom base classes."""
221+
code = """
222+
from pydantic import BaseModel, Field
223+
224+
class CustomBase(BaseModel):
225+
'''A custom base model.'''
226+
pass
227+
228+
class CustomModel(CustomBase):
229+
'''A model inheriting from custom base.'''
230+
field1: str = Field(..., description="Test field")
231+
"""
232+
with temporary_visited_package(
233+
"package",
234+
modules={"__init__.py": code},
235+
extensions=Extensions(PydanticExtension(schema=False, extra_bases=["package.CustomBase"])),
236+
) as package:
237+
# Both CustomBase and CustomModel should be detected
238+
assert "pydantic-model" in package["CustomBase"].labels
239+
assert "pydantic-model" in package["CustomModel"].labels
240+
assert "pydantic-field" in package["CustomModel.field1"].labels
241+
242+
219243
def test_process_non_model_base_class_fields() -> None:
220244
"""Fields in a non-model base class must be processed."""
221245
code = """

0 commit comments

Comments
 (0)