diff --git a/src/psygnal/_evented_decorator.py b/src/psygnal/_evented_decorator.py index 236015c8..c638eadd 100644 --- a/src/psygnal/_evented_decorator.py +++ b/src/psygnal/_evented_decorator.py @@ -1,12 +1,10 @@ +from __future__ import annotations + from typing import ( Any, Callable, - Dict, Literal, - Optional, - Type, TypeVar, - Union, overload, ) @@ -14,11 +12,9 @@ __all__ = ["evented"] -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=type) EqOperator = Callable[[Any, Any], bool] -PSYGNAL_GROUP_NAME = "_psygnal_group_" -_NULL = object() @overload @@ -26,7 +22,7 @@ def evented( cls: T, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = ..., cache_on_instance: bool = ..., ) -> T: ... @@ -34,23 +30,23 @@ def evented( @overload def evented( - cls: "Optional[Literal[None]]" = None, + cls: Literal[None] | None = None, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = ..., cache_on_instance: bool = ..., ) -> Callable[[T], T]: ... def evented( - cls: Optional[T] = None, + cls: T | None = None, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = True, cache_on_instance: bool = True, -) -> Union[Callable[[T], T], T]: +) -> Callable[[T], T] | T: """A decorator to add events to a dataclass. See also the documentation for @@ -71,7 +67,7 @@ def evented( The class to decorate. events_namespace : str The name of the namespace to add the events to, by default `"events"` - equality_operators : Optional[Dict[str, Callable]] + equality_operators : dict[str, Callable] | None A dictionary mapping field names to equality operators (a function that takes two values and returns `True` if they are equal). These will be used to determine if a field has changed when setting a new value. By default, this @@ -122,7 +118,7 @@ def _decorate(cls: T) -> T: if any(k.startswith("_psygnal") for k in getattr(cls, "__annotations__", {})): raise TypeError("Fields on an evented class cannot start with '_psygnal'") - descriptor = SignalGroupDescriptor( + descriptor: SignalGroupDescriptor = SignalGroupDescriptor( equality_operators=equality_operators, warn_on_no_fields=warn_on_no_fields, cache_on_instance=cache_on_instance, diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index 5a836de5..9132b2dc 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -1,11 +1,11 @@ from __future__ import annotations import contextlib +import copy import operator import sys import warnings import weakref -from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -34,6 +34,7 @@ T = TypeVar("T", bound=Type) S = TypeVar("S") + EqOperator = Callable[[Any, Any], bool] _EQ_OPERATORS: dict[type, dict[str, EqOperator]] = {} _EQ_OPERATOR_NAME = "__eq_operators__" @@ -141,11 +142,25 @@ def connect_setattr( ) -@lru_cache(maxsize=None) def _build_dataclass_signal_group( - cls: type, equality_operators: Iterable[tuple[str, EqOperator]] | None = None + cls: type, + signal_group_class: type[SignalGroup], + equality_operators: Iterable[tuple[str, EqOperator]] | None = None, ) -> type[SignalGroup]: - """Build a SignalGroup with events for each field in a dataclass.""" + """Build a SignalGroup with events for each field in a dataclass. + + Parameters + ---------- + cls : type + the dataclass to look for the fields to connect with signals. + signal_group_class: type[SignalGroup] + SignalGroup or a subclass of it, to use as a super class. + Default to SignalGroup + equality_operators: Iterable[tuple[str, EqOperator]] | None + If defined, a mapping of field name and equality operator to use to compare if + each field was modified after being set. + Default to None + """ _equality_operators = dict(equality_operators) if equality_operators else {} signals = {} eq_map = _get_eq_operator_map(cls) @@ -162,7 +177,9 @@ def _build_dataclass_signal_group( # patch in our custom SignalInstance class with maxargs=1 on connect_setattr sig._signal_instance_class = _DataclassFieldSignalInstance - return type(f"{cls.__name__}SignalGroup", (SignalGroup,), signals) + # Create `signal_group_class` subclass with the attached signals + group_name = f"{cls.__name__}{signal_group_class.__name__}" + return type(group_name, (signal_group_class,), signals) def is_evented(obj: object) -> bool: @@ -335,8 +352,6 @@ def __setattr__(self, name: str, value: Any) -> None: field to determine whether to emit an event. If not provided, the default equality operator is `operator.eq`, except for numpy arrays, where `np.array_equal` is used. - signal_group_class : type[SignalGroup], optional - A custom SignalGroup class to use, by default None warn_on_no_fields : bool, optional If `True` (the default), a warning will be emitted if no mutable dataclass-like fields are found on the object. @@ -352,6 +367,13 @@ def __setattr__(self, name: str, value: Any) -> None: events when fields change. If `False`, no `__setattr__` method will be created. (This will prevent signal emission, and assumes you are using a different mechanism to emit signals when fields change.) + signal_group_class : type[SignalGroup] | None, optional + A custom SignalGroup class to use, SignalGroup if None, by default None + collect_fields : bool, optional + Create a signal for each field in the dataclass. If True, the `SignalGroup` + instance will be a subclass of `signal_group_class` (SignalGroup if it is None). + If False, a deepcopy of `signal_group_class` will be used. + Default to True Examples -------- @@ -374,22 +396,42 @@ class Person: ``` """ + # map of id(obj) -> SignalGroup + # cached here in case the object isn't modifiable + _instance_map: ClassVar[dict[int, SignalGroup]] = {} + def __init__( self, *, equality_operators: dict[str, EqOperator] | None = None, - signal_group_class: type[SignalGroup] | None = None, warn_on_no_fields: bool = True, cache_on_instance: bool = True, patch_setattr: bool = True, + signal_group_class: type[SignalGroup] | None = None, + collect_fields: bool = True, ): - self._signal_group = signal_group_class + grp_cls = signal_group_class or SignalGroup + if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)): + raise TypeError( # pragma: no cover + f"'signal_group_class' must be a subclass of SignalGroup, " + f"not {grp_cls}" + ) + if grp_cls is SignalGroup and collect_fields is False: + raise ValueError( + "Cannot use SignalGroup with collect_fields=False. " + "Use a custom SignalGroup subclass instead." + ) + self._name: str | None = None self._eqop = tuple(equality_operators.items()) if equality_operators else None self._warn_on_no_fields = warn_on_no_fields self._cache_on_instance = cache_on_instance self._patch_setattr = patch_setattr + self._signal_group_class: type[SignalGroup] = grp_cls + self._collect_fields = collect_fields + self._signal_groups: dict[int, type[SignalGroup]] = {} + def __set_name__(self, owner: type, name: str) -> None: """Called when this descriptor is added to class `owner` as attribute `name`.""" self._name = name @@ -417,10 +459,6 @@ def _do_patch_setattr(self, owner: type) -> None: "emitted when fields change." ) from e - # map of id(obj) -> SignalGroup - # cached here in case the object isn't modifiable - _instance_map: ClassVar[dict[int, SignalGroup]] = {} - @overload def __get__(self, instance: None, owner: type) -> SignalGroupDescriptor: ... @@ -434,13 +472,15 @@ def __get__( if instance is None: return self + signal_group = self._get_signal_group(owner) + # if we haven't yet instantiated a SignalGroup for this instance, # do it now and cache it. Note that we cache it here in addition to # the instance (in case the instance is not modifiable). obj_id = id(instance) if obj_id not in self._instance_map: # cache it - self._instance_map[obj_id] = self._create_group(owner)(instance) + self._instance_map[obj_id] = signal_group(instance) # also *try* to set it on the instance as well, since it will skip all the # __get__ logic in the future, but if it fails, no big deal. if self._name and self._cache_on_instance: @@ -453,13 +493,28 @@ def __get__( return self._instance_map[obj_id] + def _get_signal_group(self, owner: type) -> type[SignalGroup]: + type_id = id(owner) + if type_id not in self._signal_groups: + self._signal_groups[type_id] = self._create_group(owner) + return self._signal_groups[type_id] + def _create_group(self, owner: type) -> type[SignalGroup]: - Group = self._signal_group or _build_dataclass_signal_group(owner, self._eqop) + # Do not collect fields from owner class, copy the SignalGroup + if not self._collect_fields: + Group = copy.deepcopy(self._signal_group_class) + + # Collect fields and create SignalGroup subclass + else: + Group = _build_dataclass_signal_group( + owner, self._signal_group_class, equality_operators=self._eqop + ) if self._warn_on_no_fields and not Group._psygnal_signals: warnings.warn( f"No mutable fields found on class {owner}: no events will be " "emitted. (Is this a dataclass, attrs, msgspec, or pydantic model?)", stacklevel=2, ) + self._do_patch_setattr(owner) return Group diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index 1b36f032..d54805c1 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -1,10 +1,21 @@ +from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any, ClassVar, Optional, Type from unittest.mock import Mock, patch import pytest -from psygnal import SignalGroupDescriptor, _compiled, _group_descriptor +from psygnal import ( + Signal, + SignalGroup, + SignalGroupDescriptor, + _compiled, + _group_descriptor, +) + + +class MyGroup(SignalGroup): + sig = Signal() @pytest.mark.parametrize("type_", ["dataclass", "pydantic", "attrs", "msgspec"]) @@ -213,3 +224,50 @@ class Bar: # when using connect_setattr with maxargs=None # remove this test if/when we change maxargs to default to 1 on SignalInstance assert bar.y == (2, 1) # type: ignore + + +@pytest.mark.parametrize("collect", [True, False]) +@pytest.mark.parametrize("klass", [None, SignalGroup, MyGroup]) +def test_collect_fields(collect: bool, klass: Optional[Type[SignalGroup]]) -> None: + signal_class = klass or SignalGroup + should_fail_def = signal_class is SignalGroup and collect is False + ctx = pytest.raises(ValueError) if should_fail_def else nullcontext() + + with ctx: + + @dataclass + class Foo: + events: ClassVar = SignalGroupDescriptor( + warn_on_no_fields=False, + signal_group_class=klass, + collect_fields=collect, + ) + a: int = 1 + + if should_fail_def: + return + + @dataclass + class Bar(Foo): + b: float = 2.0 + + foo = Foo() + bar = Bar() + + assert issubclass(type(foo.events), signal_class) + + if collect: + assert type(foo.events) is not signal_class + assert "a" in foo.events + assert "a" in bar.events + assert "b" in bar.events + + else: + assert type(foo.events) == signal_class + assert "a" not in foo.events + assert "a" not in bar.events + assert "b" not in bar.events + + if signal_class is MyGroup: + assert "sig" in foo.events + assert "sig" in bar.events