diff --git a/src/griffe_pydantic/_internal/extension.py b/src/griffe_pydantic/_internal/extension.py index 8f14241..0b5594f 100644 --- a/src/griffe_pydantic/_internal/extension.py +++ b/src/griffe_pydantic/_internal/extension.py @@ -22,14 +22,21 @@ class PydanticExtension(Extension): """Griffe extension for Pydantic.""" - def __init__(self, *, schema: bool = False) -> None: + def __init__( + self, + *, + schema: bool = False, + extra_bases: list[str] | None = None, + ) -> None: """Initialize the extension. Parameters: schema: Whether to compute and store the JSON schema of models. + extra_bases: Additional base classes to detect as Pydantic models. """ super().__init__() self._schema = schema + self._extra_bases = extra_bases or [] self._processed: set[str] = set() self._recorded: list[tuple[ObjectNode, Class]] = [] @@ -37,10 +44,27 @@ def on_package(self, *, pkg: Module, **kwargs: Any) -> None: # noqa: ARG002 """Detect models once the whole package is loaded.""" for node, cls in self._recorded: self._processed.add(cls.canonical_path) - dynamic._process_class(node.obj, cls, processed=self._processed, schema=self._schema) - static._process_module(pkg, processed=self._processed, schema=self._schema) + dynamic._process_class( + node.obj, + cls, + processed=self._processed, + schema=self._schema, + ) + self._recorded.clear() + static._process_module( + pkg, + processed=self._processed, + schema=self._schema, + extra_bases=self._extra_bases, + ) - def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None: # noqa: ARG002 + def on_class_instance( + self, + *, + node: ast.AST | ObjectNode, + cls: Class, + **kwargs: Any, # noqa: ARG002 + ) -> None: """Detect and prepare Pydantic models.""" # Prevent running during static analysis. if isinstance(node, ast.AST): @@ -52,5 +76,46 @@ def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: _logger.warning("could not import pydantic - models will not be detected") return + # Check if it's a standard Pydantic model if issubclass(node.obj, pydantic.BaseModel): self._recorded.append((node, cls)) + return + + # Check if it's a subclass of any extra base classes + for extra_base in self._extra_bases: + try: + # Import the extra base class + parts = extra_base.split(".") + module_name = ".".join(parts[:-1]) + class_name = parts[-1] + + if module_name: + import importlib # noqa: PLC0415 + + module = importlib.import_module(module_name) + base_class = getattr(module, class_name) + else: + # Handle case where only class name is provided (in current module) + base_class = globals().get(class_name) + if base_class is None: + continue + + if base_class and issubclass(node.obj, base_class): + try: + # Verify that this extra base ultimately inherits from BaseModel + if issubclass(base_class, pydantic.BaseModel): + self._recorded.append((node, cls)) + return + _logger.debug(f"Extra base class {extra_base} does not inherit from pydantic.BaseModel") + except TypeError: + # issubclass() can raise TypeError if base_class is not a class + _logger.debug(f"Extra base {extra_base} is not a valid class for issubclass check") + continue + except (ImportError, AttributeError): + # Skip if we can't import or resolve the extra base + _logger.debug(f"Could not resolve extra base class: {extra_base}") + continue + except TypeError: + # issubclass() with node.obj failed + _logger.debug(f"Could not check inheritance from extra base class: {extra_base}") + continue diff --git a/src/griffe_pydantic/_internal/static.py b/src/griffe_pydantic/_internal/static.py index 73993a0..aec1ee8 100644 --- a/src/griffe_pydantic/_internal/static.py +++ b/src/griffe_pydantic/_internal/static.py @@ -29,22 +29,28 @@ _logger = get_logger("griffe_pydantic") -def _inherits_pydantic(cls: Class) -> bool: +def _inherits_pydantic(cls: Class, extra_bases: list[str] | None = None) -> bool: """Tell whether a class inherits from a Pydantic model. Parameters: cls: A Griffe class. + extra_bases: Additional base classes to consider as Pydantic models. Returns: True/False. """ + extra_bases = extra_bases or [] + for base in cls.bases: if isinstance(base, (ExprName, Expr)): base = base.canonical_path # noqa: PLW2901 if base in {"pydantic.BaseModel", "pydantic.main.BaseModel"}: return True + # Check if the base matches any of the extra bases + if base in extra_bases: + return True - return any(_inherits_pydantic(parent_class) for parent_class in cls.mro()) + return any(_inherits_pydantic(parent_class, extra_bases) for parent_class in cls.mro()) def _pydantic_validator(func: Function) -> ExprCall | None: @@ -123,7 +129,9 @@ def _process_attribute(attr: Attribute, cls: Class, *, processed: set[str]) -> N try: attr.docstring = Docstring(ast.literal_eval(docstring), parent=attr) # type: ignore[arg-type] except ValueError: - _logger.debug(f"Could not parse description of field '{attr.path}' as literal, skipping") + _logger.debug( + f"Could not parse description of field '{attr.path}' as literal, skipping", + ) def _process_function(func: Function, cls: Class, *, processed: set[str]) -> None: @@ -141,12 +149,18 @@ def _process_function(func: Function, cls: Class, *, processed: set[str]) -> Non common._process_function(func, cls, fields) -def _process_class(cls: Class, *, processed: set[str], schema: bool = False) -> None: +def _process_class( + cls: Class, + *, + processed: set[str], + schema: bool = False, + extra_bases: list[str] | None = None, +) -> None: """Finalize the Pydantic model data.""" if cls.canonical_path in processed: return - if not _inherits_pydantic(cls): + if not _inherits_pydantic(cls, extra_bases): return processed.add(cls.canonical_path) @@ -174,7 +188,12 @@ def _process_class(cls: Class, *, processed: set[str], schema: bool = False) -> elif kind is Kind.FUNCTION: _process_function(member, cls, processed=processed) # type: ignore[arg-type] elif kind is Kind.CLASS: - _process_class(member, processed=processed, schema=schema) # type: ignore[arg-type] + _process_class( + member, # type: ignore[arg-type] + processed=processed, + schema=schema, + extra_bases=extra_bases, + ) def _process_module( @@ -182,6 +201,7 @@ def _process_module( *, processed: set[str], schema: bool = False, + extra_bases: list[str] | None = None, ) -> None: """Handle Pydantic models in a module.""" if mod.canonical_path in processed: @@ -191,7 +211,17 @@ def _process_module( for cls in mod.classes.values(): # Don't process aliases, real classes will be processed at some point anyway. if not cls.is_alias: - _process_class(cls, processed=processed, schema=schema) + _process_class( + cls, + processed=processed, + schema=schema, + extra_bases=extra_bases, + ) for submodule in mod.modules.values(): - _process_module(submodule, processed=processed, schema=schema) + _process_module( + submodule, + processed=processed, + schema=schema, + extra_bases=extra_bases, + ) diff --git a/tests/test_api.py b/tests/test_api.py index e1ef1b8..e4d4be1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -73,18 +73,31 @@ def _yield_public_objects( @pytest.fixture(name="modulelevel_internal_objects", scope="module") -def _fixture_modulelevel_internal_objects(internal_api: griffe.Module) -> list[griffe.Object | griffe.Alias]: +def _fixture_modulelevel_internal_objects( + internal_api: griffe.Module, +) -> list[griffe.Object | griffe.Alias]: return list(_yield_public_objects(internal_api, modulelevel=True)) @pytest.fixture(name="internal_objects", scope="module") -def _fixture_internal_objects(internal_api: griffe.Module) -> list[griffe.Object | griffe.Alias]: +def _fixture_internal_objects( + internal_api: griffe.Module, +) -> list[griffe.Object | griffe.Alias]: return list(_yield_public_objects(internal_api, modulelevel=False, special=True)) @pytest.fixture(name="public_objects", scope="module") -def _fixture_public_objects(public_api: griffe.Module) -> list[griffe.Object | griffe.Alias]: - return list(_yield_public_objects(public_api, modulelevel=False, inherited=True, special=True)) +def _fixture_public_objects( + public_api: griffe.Module, +) -> list[griffe.Object | griffe.Alias]: + return list( + _yield_public_objects( + public_api, + modulelevel=False, + inherited=True, + special=True, + ), + ) @pytest.fixture(name="inventory", scope="module") @@ -96,7 +109,9 @@ def _fixture_inventory() -> Inventory: return Inventory.parse_sphinx(file) -def test_exposed_objects(modulelevel_internal_objects: list[griffe.Object | griffe.Alias]) -> None: +def test_exposed_objects( + modulelevel_internal_objects: list[griffe.Object | griffe.Alias], +) -> None: """All public objects in the internal API are exposed under `griffe_pydantic`.""" not_exposed = [ obj.path @@ -106,7 +121,9 @@ def test_exposed_objects(modulelevel_internal_objects: list[griffe.Object | grif assert not not_exposed, "Objects not exposed:\n" + "\n".join(sorted(not_exposed)) -def test_unique_names(modulelevel_internal_objects: list[griffe.Object | griffe.Alias]) -> None: +def test_unique_names( + modulelevel_internal_objects: list[griffe.Object | griffe.Alias], +) -> None: """All internal objects have unique names.""" names_to_paths = defaultdict(list) for obj in modulelevel_internal_objects: @@ -133,7 +150,10 @@ def _public_path(obj: griffe.Object | griffe.Alias) -> bool: ) -def test_api_matches_inventory(inventory: Inventory, public_objects: list[griffe.Object | griffe.Alias]) -> None: +def test_api_matches_inventory( + inventory: Inventory, + public_objects: list[griffe.Object | griffe.Alias], +) -> None: """All public objects are added to the inventory.""" ignore_names = {"__getattr__", "__init__", "__repr__", "__str__", "__post_init__"} not_in_inventory = [ diff --git a/tests/test_extension.py b/tests/test_extension.py index c540748..c86e87e 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -216,6 +216,30 @@ class Model(Base): assert "pydantic-field" not in package["Model.a"].labels +def test_extra_bases() -> None: + """Test extra_bases functionality with custom base classes.""" + code = """ + from pydantic import BaseModel, Field + + class CustomBase(BaseModel): + '''A custom base model.''' + pass + + class CustomModel(CustomBase): + '''A model inheriting from custom base.''' + field1: str = Field(..., description="Test field") + """ + with temporary_visited_package( + "package", + modules={"__init__.py": code}, + extensions=Extensions(PydanticExtension(schema=False, extra_bases=["package.CustomBase"])), + ) as package: + # Both CustomBase and CustomModel should be detected + assert "pydantic-model" in package["CustomBase"].labels + assert "pydantic-model" in package["CustomModel"].labels + assert "pydantic-field" in package["CustomModel.field1"].labels + + def test_process_non_model_base_class_fields() -> None: """Fields in a non-model base class must be processed.""" code = """