Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions src/griffe_pydantic/_internal/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,49 @@
class PydanticExtension(Extension):
"""Griffe extension for Pydantic."""

def __init__(
    self,
    *,
    schema: bool = False,
    extra_bases: list[str] | None = None,
) -> None:
    """Initialize the extension.

    Parameters:
        schema: Whether to compute and store the JSON schema of models.
        extra_bases: Additional base classes to detect as Pydantic models.
    """
    super().__init__()
    # Whether to compute and attach JSON schemas for detected models.
    self._schema = schema
    # Canonical paths of extra base classes treated as Pydantic models.
    self._extra_bases = [] if extra_bases is None else extra_bases
    # Canonical paths of classes already processed, to avoid double work.
    self._processed: set[str] = set()
    # Models recorded during dynamic analysis; finalized in `on_package`.
    self._recorded: list[tuple[ObjectNode, Class]] = []

def on_package(self, *, pkg: Module, **kwargs: Any) -> None:  # noqa: ARG002
    """Detect models once the whole package is loaded."""
    # Finalize models recorded during dynamic analysis first.
    for node, cls in self._recorded:
        self._processed.add(cls.canonical_path)
        dynamic._process_class(node.obj, cls, processed=self._processed, schema=self._schema)
    # Recorded entries are consumed; reset so a later run starts clean.
    self._recorded.clear()
    # Then walk the whole package statically, honoring any extra bases.
    static._process_module(pkg, processed=self._processed, schema=self._schema, extra_bases=self._extra_bases)

def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None: # noqa: ARG002
def on_class_instance(
self,
*,
node: ast.AST | ObjectNode,
cls: Class,
**kwargs: Any, # noqa: ARG002
) -> None:
"""Detect and prepare Pydantic models."""
# Prevent running during static analysis.
if isinstance(node, ast.AST):
Expand All @@ -52,5 +76,46 @@ def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs:
_logger.warning("could not import pydantic - models will not be detected")
return

# Check if it's a standard Pydantic model
if issubclass(node.obj, pydantic.BaseModel):
self._recorded.append((node, cls))
return

# Check if it's a subclass of any extra base classes
for extra_base in self._extra_bases:
try:
# Import the extra base class
parts = extra_base.split(".")
module_name = ".".join(parts[:-1])
class_name = parts[-1]

if module_name:
import importlib # noqa: PLC0415

module = importlib.import_module(module_name)
base_class = getattr(module, class_name)
else:
# Handle case where only class name is provided (in current module)
base_class = globals().get(class_name)
if base_class is None:
continue

if base_class and issubclass(node.obj, base_class):
try:
# Verify that this extra base ultimately inherits from BaseModel
if issubclass(base_class, pydantic.BaseModel):
self._recorded.append((node, cls))
return
_logger.debug(f"Extra base class {extra_base} does not inherit from pydantic.BaseModel")
except TypeError:
# issubclass() can raise TypeError if base_class is not a class
_logger.debug(f"Extra base {extra_base} is not a valid class for issubclass check")
continue
except (ImportError, AttributeError):
# Skip if we can't import or resolve the extra base
_logger.debug(f"Could not resolve extra base class: {extra_base}")
continue
except TypeError:
# issubclass() with node.obj failed
_logger.debug(f"Could not check inheritance from extra base class: {extra_base}")
continue
46 changes: 38 additions & 8 deletions src/griffe_pydantic/_internal/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,28 @@
_logger = get_logger("griffe_pydantic")


def _inherits_pydantic(cls: Class, extra_bases: list[str] | None = None) -> bool:
    """Tell whether a class inherits from a Pydantic model.

    Parameters:
        cls: A Griffe class.
        extra_bases: Additional base classes to consider as Pydantic models.

    Returns:
        True/False.
    """
    # The canonical paths recognized as Pydantic model bases,
    # extended with any user-configured extra bases.
    recognized = {"pydantic.BaseModel", "pydantic.main.BaseModel"}
    recognized.update(extra_bases or [])

    for base in cls.bases:
        # Expressions are normalized to their canonical dotted path.
        path = base.canonical_path if isinstance(base, (ExprName, Expr)) else base
        if path in recognized:
            return True

    # No direct base matched: look further up the class hierarchy.
    return any(_inherits_pydantic(parent, extra_bases) for parent in cls.mro())


def _pydantic_validator(func: Function) -> ExprCall | None:
Expand Down Expand Up @@ -123,7 +129,9 @@ def _process_attribute(attr: Attribute, cls: Class, *, processed: set[str]) -> N
try:
attr.docstring = Docstring(ast.literal_eval(docstring), parent=attr) # type: ignore[arg-type]
except ValueError:
_logger.debug(f"Could not parse description of field '{attr.path}' as literal, skipping")
_logger.debug(
f"Could not parse description of field '{attr.path}' as literal, skipping",
)


def _process_function(func: Function, cls: Class, *, processed: set[str]) -> None:
Expand All @@ -141,12 +149,18 @@ def _process_function(func: Function, cls: Class, *, processed: set[str]) -> Non
common._process_function(func, cls, fields)


def _process_class(cls: Class, *, processed: set[str], schema: bool = False) -> None:
def _process_class(
cls: Class,
*,
processed: set[str],
schema: bool = False,
extra_bases: list[str] | None = None,
) -> None:
"""Finalize the Pydantic model data."""
if cls.canonical_path in processed:
return

if not _inherits_pydantic(cls):
if not _inherits_pydantic(cls, extra_bases):
return

processed.add(cls.canonical_path)
Expand Down Expand Up @@ -174,14 +188,20 @@ def _process_class(cls: Class, *, processed: set[str], schema: bool = False) ->
elif kind is Kind.FUNCTION:
_process_function(member, cls, processed=processed) # type: ignore[arg-type]
elif kind is Kind.CLASS:
_process_class(member, processed=processed, schema=schema) # type: ignore[arg-type]
_process_class(
member, # type: ignore[arg-type]
processed=processed,
schema=schema,
extra_bases=extra_bases,
)


def _process_module(
mod: Module,
*,
processed: set[str],
schema: bool = False,
extra_bases: list[str] | None = None,
) -> None:
"""Handle Pydantic models in a module."""
if mod.canonical_path in processed:
Expand All @@ -191,7 +211,17 @@ def _process_module(
for cls in mod.classes.values():
# Don't process aliases, real classes will be processed at some point anyway.
if not cls.is_alias:
_process_class(cls, processed=processed, schema=schema)
_process_class(
cls,
processed=processed,
schema=schema,
extra_bases=extra_bases,
)

for submodule in mod.modules.values():
_process_module(submodule, processed=processed, schema=schema)
_process_module(
submodule,
processed=processed,
schema=schema,
extra_bases=extra_bases,
)
34 changes: 27 additions & 7 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,31 @@ def _yield_public_objects(


@pytest.fixture(name="modulelevel_internal_objects", scope="module")
def _fixture_modulelevel_internal_objects(
    internal_api: griffe.Module,
) -> list[griffe.Object | griffe.Alias]:
    """Collect the public module-level objects of the internal API."""
    objects = _yield_public_objects(internal_api, modulelevel=True)
    return list(objects)


@pytest.fixture(name="internal_objects", scope="module")
def _fixture_internal_objects(
    internal_api: griffe.Module,
) -> list[griffe.Object | griffe.Alias]:
    """Collect all public objects (including special members) of the internal API."""
    objects = _yield_public_objects(internal_api, modulelevel=False, special=True)
    return list(objects)


@pytest.fixture(name="public_objects", scope="module")
def _fixture_public_objects(
    public_api: griffe.Module,
) -> list[griffe.Object | griffe.Alias]:
    """Collect all public objects (including inherited and special members) of the public API."""
    objects = _yield_public_objects(
        public_api,
        modulelevel=False,
        inherited=True,
        special=True,
    )
    return list(objects)


@pytest.fixture(name="inventory", scope="module")
Expand All @@ -96,7 +109,9 @@ def _fixture_inventory() -> Inventory:
return Inventory.parse_sphinx(file)


def test_exposed_objects(modulelevel_internal_objects: list[griffe.Object | griffe.Alias]) -> None:
def test_exposed_objects(
modulelevel_internal_objects: list[griffe.Object | griffe.Alias],
) -> None:
"""All public objects in the internal API are exposed under `griffe_pydantic`."""
not_exposed = [
obj.path
Expand All @@ -106,7 +121,9 @@ def test_exposed_objects(modulelevel_internal_objects: list[griffe.Object | grif
assert not not_exposed, "Objects not exposed:\n" + "\n".join(sorted(not_exposed))


def test_unique_names(modulelevel_internal_objects: list[griffe.Object | griffe.Alias]) -> None:
def test_unique_names(
modulelevel_internal_objects: list[griffe.Object | griffe.Alias],
) -> None:
"""All internal objects have unique names."""
names_to_paths = defaultdict(list)
for obj in modulelevel_internal_objects:
Expand All @@ -133,7 +150,10 @@ def _public_path(obj: griffe.Object | griffe.Alias) -> bool:
)


def test_api_matches_inventory(inventory: Inventory, public_objects: list[griffe.Object | griffe.Alias]) -> None:
def test_api_matches_inventory(
inventory: Inventory,
public_objects: list[griffe.Object | griffe.Alias],
) -> None:
"""All public objects are added to the inventory."""
ignore_names = {"__getattr__", "__init__", "__repr__", "__str__", "__post_init__"}
not_in_inventory = [
Expand Down
24 changes: 24 additions & 0 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,30 @@ class Model(Base):
assert "pydantic-field" not in package["Model.a"].labels


def test_extra_bases() -> None:
    """Test extra_bases functionality with custom base classes."""
    # A custom intermediate base between user models and pydantic.BaseModel.
    code = """
from pydantic import BaseModel, Field

class CustomBase(BaseModel):
    '''A custom base model.'''
    pass

class CustomModel(CustomBase):
    '''A model inheriting from custom base.'''
    field1: str = Field(..., description="Test field")
"""
    extension = PydanticExtension(schema=False, extra_bases=["package.CustomBase"])
    with temporary_visited_package(
        "package",
        modules={"__init__.py": code},
        extensions=Extensions(extension),
    ) as package:
        # Both the custom base and its subclass should be detected as models.
        assert "pydantic-model" in package["CustomBase"].labels
        assert "pydantic-model" in package["CustomModel"].labels
        assert "pydantic-field" in package["CustomModel.field1"].labels


def test_process_non_model_base_class_fields() -> None:
"""Fields in a non-model base class must be processed."""
code = """
Expand Down
Loading