diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 3f9599b43..9e7d610eb 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -226,7 +226,7 @@ def _get_module_method(module, method: tp.Callable[..., Any] | str | None): if not callable(method): class_name = type(module).__name__ raise TypeError( - f"'{method}' must be a callable, got {type(method)}." + f"'{class_name}.{method}' must be a callable, got {type(method)}." ) return method