From 9c419680ac1a30337c853aa21f236b226929ee6f Mon Sep 17 00:00:00 2001 From: Flax Team Date: Tue, 9 Sep 2025 04:40:35 -0700 Subject: [PATCH] Don't pass linen_meta_type argument when creating AxisMetadata subclasses. PiperOrigin-RevId: 804851220 --- flax/nnx/bridge/variables.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 95ba46f89..2e3b5adbc 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -91,7 +91,8 @@ def is_vanilla_variable(vs: variablelib.Variable) -> bool: def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata: metadata = vs.get_metadata() if 'linen_meta_type' in metadata: - linen_type = metadata['linen_meta_type'] + metadata = dict(metadata) + linen_type = metadata.pop('linen_meta_type') if hasattr(linen_type, 'from_nnx_metadata'): return linen_type.from_nnx_metadata({'value': vs.value, **metadata}) return linen_type(vs.value, **metadata)