Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/psygnal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"EventedModel",
"get_evented_namespace",
"is_evented",
"PSYGNAL_METADATA",
"Signal",
"SignalGroup",
"SignalGroupDescriptor",
Expand All @@ -48,6 +49,7 @@
stacklevel=2,
)

from ._dataclass_utils import PSYGNAL_METADATA
from ._evented_decorator import evented
from ._exceptions import EmitLoopError
from ._group import EmissionInfo, SignalGroup
Expand Down
271 changes: 263 additions & 8 deletions src/psygnal/_dataclass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,30 @@
import dataclasses
import sys
import types
from typing import TYPE_CHECKING, Any, Iterator, List, Protocol, cast, overload
from dataclasses import dataclass, fields
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Mapping,
Protocol,
cast,
overload,
)

if TYPE_CHECKING:
from dataclasses import Field

import attrs
import msgspec
from pydantic import BaseModel
from typing_extensions import TypeGuard # py310
from typing_extensions import TypeAlias, TypeGuard # py310

EqOperator: TypeAlias = Callable[[Any, Any], bool]

PSYGNAL_METADATA = "__psygnal_metadata"


class _DataclassParams(Protocol):
Expand All @@ -29,12 +46,11 @@ class AttrsType:
__attrs_attrs__: tuple[attrs.Attribute, ...]


_DATACLASS_PARAMS = "__dataclass_params__"
KW_ONLY = object()
with contextlib.suppress(ImportError):
from dataclasses import _DATACLASS_PARAMS # type: ignore
from dataclasses import KW_ONLY # py310
_DATACLASS_PARAMS = "__dataclass_params__"
_DATACLASS_FIELDS = "__dataclass_fields__"
with contextlib.suppress(ImportError):
from dataclasses import _DATACLASS_FIELDS # type: ignore


class DataClassType:
Expand Down Expand Up @@ -171,8 +187,8 @@ def iter_fields(
yield field_name, p_field.annotation
else:
for p_field in cls.__fields__.values(): # type: ignore [attr-defined]
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore
yield p_field.name, p_field.outer_type_ # type: ignore
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore [attr-defined]
yield p_field.name, p_field.outer_type_ # type: ignore [attr-defined]
return

if (attrs_fields := getattr(cls, "__attrs_attrs__", None)) is not None:
Expand All @@ -185,3 +201,242 @@ def iter_fields(
type_ = cls.__annotations__.get(m_field, None)
yield m_field, type_
return


@dataclass
class FieldOptions:
name: str
type_: type | None = None
# set KW_ONLY value for compatibility with python < 3.10
_: KW_ONLY = KW_ONLY # type: ignore [valid-type]
alias: str | None = None
skip: bool | None = None
eq: EqOperator | None = None
disable_setattr: bool | None = None


def is_kw_only(f: Field) -> bool:
if hasattr(f, "kw_only"):
return cast(bool, f.kw_only)
# for python < 3.10
if f.name not in ["name", "type_"]:
return True
return False


def sanitize_field_options_dict(d: Mapping) -> dict[str, Any]:
field_options_kws = [f.name for f in fields(FieldOptions) if is_kw_only(f)]
return {k: v for k, v in d.items() if k in field_options_kws}


def get_msgspec_metadata(
cls: type[msgspec.Struct],
m_field: str,
) -> tuple[type | None, dict[str, Any]]:
# Look for type in cls and super classes
type_: type | None = None
for super_cls in cls.__mro__:
if not hasattr(super_cls, "__annotations__"):
continue
type_ = super_cls.__annotations__.get(m_field, None)
if type_ is not None:
break

msgspec = sys.modules.get("msgspec", None)
if msgspec is None:
return type_, {}

metadata_list = getattr(type_, "__metadata__", [])

metadata: dict[str, Any] = {}
for meta in metadata_list:
if not isinstance(meta, msgspec.Meta):
continue
single_meta: dict[str, Any] = getattr(meta, "extra", {}).get(
PSYGNAL_METADATA, {}
)
metadata.update(single_meta)

return type_, metadata


def iter_fields_with_options(
cls: type, exclude_frozen: bool = True
) -> Iterator[FieldOptions]:
"""Iterate over all fields in the class, return a field description.

This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic
models.

Parameters
----------
cls : type
The class to iterate over.
exclude_frozen : bool, optional
If True, frozen fields will be excluded. By default True.

Yields
------
FieldOptions
A dataclass instance with the name, type and metadata of each field.
"""
# Add metadata for dataclasses.dataclass
dclass_fields = getattr(cls, "__dataclass_fields__", None)
if dclass_fields is not None:
"""
Example
-------
from dataclasses import dataclass, field


@dataclass
class Foo:
bar: int = field(metadata={"alias": "bar_alias"})

assert (
Foo.__dataclass_fields__["bar"].metadata ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)

"""
for d_field in dclass_fields.values():
if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined]
metadata = getattr(d_field, "metadata", {}).get(PSYGNAL_METADATA, {})
metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(d_field.name, d_field.type, **metadata)
yield options
return

# Add metadata for pydantic dataclass
if is_pydantic_model(cls):
"""
Example
-------
from typing import Annotated

from pydantic import BaseModel, Field


# Only works with Pydantic v2
class Foo(BaseModel):
bar: Annotated[
str,
{'__psygnal_metadata': {"alias": "bar_alias"}}
] = Field(...)

# Working with Pydantic v2 and partially with v1
# Alternative, using Field `json_schema_extra` keyword argument
class Bar(BaseModel):
bar: str = Field(
json_schema_extra={PSYGNAL_METADATA: {"alias": "bar_alias"}}
)


assert (
Foo.model_fields["bar"].metadata[0] ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)
assert (
Bar.model_fields["bar"].json_schema_extra ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)

"""
if hasattr(cls, "model_fields"):
# Pydantic v2
for field_name, p_field in cls.model_fields.items():
# skip frozen field
if exclude_frozen and p_field.frozen:
continue
metadata_list = getattr(p_field, "metadata", [])
metadata = {}
for field in metadata_list:
metadata.update(field.get(PSYGNAL_METADATA, {}))
# Compat with using Field `json_schema_extra` keyword argument
if isinstance(getattr(p_field, "json_schema_extra", None), Mapping):
meta_dict = cast(Mapping, p_field.json_schema_extra)
metadata.update(meta_dict.get(PSYGNAL_METADATA, {}))
metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(field_name, p_field.annotation, **metadata)
yield options
return

else:
# Pydantic v1, metadata is not always working
for pv1_field in cls.__fields__.values(): # type: ignore [attr-defined]
# skip frozen field
if exclude_frozen and not pv1_field.field_info.allow_mutation:
continue
meta_dict = getattr(pv1_field.field_info, "extra", {}).get(
"json_schema_extra", {}
)
metadata = meta_dict.get(PSYGNAL_METADATA, {})

metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(
pv1_field.name,
pv1_field.outer_type_,
**metadata,
)
yield options
return

# Add metadata for attrs dataclass
attrs_fields = getattr(cls, "__attrs_attrs__", None)
if attrs_fields is not None:
"""
Example
-------
from attrs import define, field


@define
class Foo:
bar: int = field(metadata={"alias": "bar_alias"})

assert (
Foo.__attrs_attrs__.bar.metadata ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)

"""
for a_field in attrs_fields:
metadata = getattr(a_field, "metadata", {}).get(PSYGNAL_METADATA, {})
metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(a_field.name, a_field.type, **metadata)
yield options
return

# Add metadata for attrs dataclass
if is_msgspec_struct(cls):
"""
Example
-------
from typing import Annotated

from msgspec import Meta, Struct


class Foo(Struct):
bar: Annotated[
str,
Meta(extra={"__psygnal_metadata": {"alias": "bar_alias"}))
] = ""


print(Foo.__annotations__["bar"].__metadata__[0].extra)
# {"__psygnal_metadata": {"alias": "bar_alias"}}

"""
for m_field in cls.__struct_fields__:
try:
type_, metadata = get_msgspec_metadata(cls, m_field)
metadata = sanitize_field_options_dict(metadata)
except AttributeError:
msg = f"Cannot parse field metadata for {m_field}: {type_}"
# logger.exception(msg)
print(msg)
type_, metadata = None, {}
options = FieldOptions(m_field, type_, **metadata)
yield options
return
5 changes: 4 additions & 1 deletion src/psygnal/_evented_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from psygnal._group_descriptor import SignalGroupDescriptor

if TYPE_CHECKING:
from psygnal._group_descriptor import EqOperator, FieldAliasFunc
from psygnal._group_descriptor import ( # type: ignore[attr-defined]
EqOperator,
FieldAliasFunc,
)

__all__ = ["evented"]

Expand Down
Loading