diff --git a/glass/_array_api_utils.py b/glass/_array_api_utils.py index 5f8fa992..10284068 100644 --- a/glass/_array_api_utils.py +++ b/glass/_array_api_utils.py @@ -16,7 +16,11 @@ def get_namespace(*arrays: NDArray[Any] | Array) -> ModuleType: if they do not. """ namespace = arrays[0].__array_namespace__() - if not all(array.__array_namespace__() == namespace for array in arrays): + if any( + array.__array_namespace__() != namespace + for array in arrays + if array is not None + ): msg = "input arrays should belong to the same array library" raise ValueError(msg)