From 6c7b0e4ca5bb4fa85fe4d970f86352b18107fc1b Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 13 Mar 2026 06:05:45 +0100 Subject: [PATCH 1/6] Fix lru cache --- src/gt4py/eve/utils.py | 13 +++++++++++-- tests/eve_tests/unit_tests/test_utils.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index a0e48ae557..313ef3bd4e 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -464,6 +464,15 @@ def inner(*args: Any, **kwargs: Any) -> Any: return _decorator(func) if func is not None else _decorator +class _LRUCacheValue(HashableBy): + __hash__ = HashableBy.__hash__ + + # we compare by hash here as the cache lookup in lru_cache is not only by hash but also by + # equality, but we only want to consider the key given by the user not the value. + def __eq__(self, other: Any) -> bool: + return hash(self) == hash(other) + + # TODO(egparedes): it would be more efficient to implement the caching logic # here instead of relying on `functools.lru_cache` and wrapping/unwrapping the # arguments. @@ -504,8 +513,8 @@ def cached_func(*args: HashableBy, **kwargs: HashableBy) -> _T: @functools.wraps(func) def inner(*args, **kwargs): # type: ignore[no-untyped-def] # cast below restores type info return cached_func( - *(hashable_by(key, arg) for arg in args), - **{k: hashable_by(key, arg) for k, arg in kwargs.items()}, + *(_LRUCacheValue(key, arg) for arg in args), + **{k: _LRUCacheValue(key, arg) for k, arg in kwargs.items()}, ) inner.cache_parameters = cached_func.cache_parameters # type: ignore[attr-defined] # mypy not aware of functools.lru_cache behavior diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index ae8e938396..4033eab3cb 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -269,6 +269,20 @@ def func(x): assert cached.cache_info().hits == 1 assert cached.cache_info().misses == 1 +def test_lru_cache_no_eq_call(): + class A: + def __hash__(self) -> int: + return 1 + + def __eq__(self, other): + raise ValueError() # this function should never be called + + @eve.utils.lru_cache(key=lambda x: hash(x)) + def func(x): + pass + + func(A()) + func(A()) def test_fluid_partial(): from gt4py.eve.utils import fluid_partial From 849df8926112c0da42d7243a012251213e3fb20e Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 13 Mar 2026 06:07:08 +0100 Subject: [PATCH 2/6] Fix format --- tests/eve_tests/unit_tests/test_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index 4033eab3cb..63ddd6d68e 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -269,6 +269,7 @@ def func(x): assert cached.cache_info().hits == 1 assert cached.cache_info().misses == 1 + def test_lru_cache_no_eq_call(): class A: def __hash__(self) -> int: @@ -284,6 +285,7 @@ def func(x): func(A()) func(A()) + def test_fluid_partial(): from gt4py.eve.utils import fluid_partial From 17456a58bcaded631ca3410f74093c7edf5d6f26 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 20 Mar 2026 10:16:02 +0100 Subject: [PATCH 3/6] Update src/gt4py/eve/utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- src/gt4py/eve/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 313ef3bd4e..46b42b8e36 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -464,13 +464,12 @@ def inner(*args: Any, **kwargs: Any) -> Any: return _decorator(func) if func is not None else _decorator -class _LRUCacheValue(HashableBy): +class EqualityBy(HashableBy): + """Use a hash function as the definition of equality for the wrapped object.""" __hash__ = HashableBy.__hash__ - # we compare by hash here as the cache lookup in lru_cache is not only by hash but also by - # equality, but we only want to consider the key given by the user not the value. def __eq__(self, other: Any) -> bool: - return hash(self) == hash(other) + return self is other or hash(self) == hash(other) # TODO(egparedes): it would be more efficient to implement the caching logic From 49498da13cf6132d64f0a574a92737f5cd70026b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 20 Mar 2026 10:16:09 +0100 Subject: [PATCH 4/6] Update src/gt4py/eve/utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- src/gt4py/eve/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 46b42b8e36..400513101a 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -485,7 +485,8 @@ def lru_cache( """ Wrap :func:`functools.lru_cache` but allow customizing the cache key. - Be careful: `key(obj1) == key(obj2)` must imply `obj1 == obj2`. + Be careful, with custom `key` functions, `key(obj1) == key(obj2)` automatically + implies `obj1 == obj2`. >>> @lru_cache(key=id) ... def func(x): From 4f75ea686f98c311d0683d567581a792e008b030 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 20 Mar 2026 10:18:09 +0100 Subject: [PATCH 5/6] Update utils.py --- src/gt4py/eve/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 400513101a..14be84f020 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -486,7 +486,7 @@ def lru_cache( Wrap :func:`functools.lru_cache` but allow customizing the cache key. Be careful, with custom `key` functions, `key(obj1) == key(obj2)` automatically - implies `obj1 == obj2`. + implies `obj1 == obj2`, i.e. they are considered equal. >>> @lru_cache(key=id) ... def func(x): From dd38119bc01f3e766009a1cf996802023ac84af6 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 20 Mar 2026 12:56:40 +0100 Subject: [PATCH 6/6] Fix --- src/gt4py/eve/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 14be84f020..ce6967bd53 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -466,6 +466,7 @@ def inner(*args: Any, **kwargs: Any) -> Any: class EqualityBy(HashableBy): """Use a hash function as the definition of equality for the wrapped object.""" + __hash__ = HashableBy.__hash__ def __eq__(self, other: Any) -> bool: @@ -513,8 +514,8 @@ def cached_func(*args: HashableBy, **kwargs: HashableBy) -> _T: @functools.wraps(func) def inner(*args, **kwargs): # type: ignore[no-untyped-def] # cast below restores type info return cached_func( - *(_LRUCacheValue(key, arg) for arg in args), - **{k: _LRUCacheValue(key, arg) for k, arg in kwargs.items()}, + *(EqualityBy(key, arg) for arg in args), + **{k: EqualityBy(key, arg) for k, arg in kwargs.items()}, ) inner.cache_parameters = cached_func.cache_parameters # type: ignore[attr-defined] # mypy not aware of functools.lru_cache behavior