diff --git a/narwhals/utils.py b/narwhals/utils.py
index 7f1229793f..7261806898 100644
--- a/narwhals/utils.py
+++ b/narwhals/utils.py
@@ -387,11 +387,7 @@ def remove_suffix(text: str, suffix: str) -> str:  # pragma: no cover
 
 
 def flatten(args: Any) -> list[Any]:
-    if not args:
-        return []
-    if len(args) == 1 and _is_iterable(args[0]):
-        return args[0]  # type: ignore[no-any-return]
-    return args  # type: ignore[no-any-return]
+    return list(args[0] if (len(args) == 1 and _is_iterable(args[0])) else args)
 
 
 def tupleify(arg: Any) -> Any: