Skip to content

Commit ee7fae4

Browse files
author
Flax Team
committed
Don't pass linen_meta_type argument when creating AxisMetadata subclasses.
PiperOrigin-RevId: 804851220
1 parent f4bc868 commit ee7fae4

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

flax/nnx/bridge/variables.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def is_vanilla_variable(vs: variablelib.Variable) -> bool:
9191
def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata:
9292
metadata = vs.get_metadata()
9393
if 'linen_meta_type' in metadata:
94-
linen_type = metadata['linen_meta_type']
94+
metadata = dict(metadata)
95+
linen_type = metadata.pop('linen_meta_type')
9596
if hasattr(linen_type, 'from_nnx_metadata'):
9697
return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
9798
return linen_type(vs.value, **metadata)

0 commit comments

Comments
 (0)