diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py index 2789732d3..82d65db99 100644 --- a/flax/nnx/summary.py +++ b/flax/nnx/summary.py @@ -93,7 +93,7 @@ def _collect_stats( if id(value) in node_stats: continue elif isinstance(value, variablelib.Variable): - var_type = type(value) + var_type = value.var_type if issubclass(var_type, nnx.RngState): var_type = nnx.RngState size_bytes = SizeBytes.from_any(value.get_value()) @@ -168,10 +168,32 @@ def inner(state, *args, **kwargs): return f(model, *args, **kwargs) return jax.vjp(inner, state, *args, **kwargs) -def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *args, **kwargs): - e = jitted.lower(obj, *args, **kwargs) - flops = _get_flops(e) if compute_flops else None - outputs = e.lowered.out_info[2] +def _get_call_info( + jitted, + method_name, + node_stats, + obj, + compute_flops: bool, + compute_vjp_flops: bool, + args, + kwargs, + outputs, +): + if compute_flops: + e = jitted.lower(obj, *args, **kwargs) + flops = _get_flops(e) + else: + flops = None + if compute_vjp_flops: + + def do_vjp(*args, **kwargs): + primals, f_vjp = _pure_nnx_vjp(jitted, obj, *args, **kwargs) + return f_vjp(primals) + + e_vjp = jax.jit(do_vjp).lower(obj, *args, **kwargs) + vjp_flops = _get_flops(e_vjp) + else: + vjp_flops = None output_repr = jax.tree.map(_to_dummy_array, outputs) input_args_info, input_kwargs_info = jax.tree.map( _to_dummy_array, (args, kwargs) @@ -191,6 +213,7 @@ def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *a input_kwargs=input_kwargs_info, outputs=output_repr, flops=flops, + vjp_flops=vjp_flops, ) @@ -211,8 +234,9 @@ def _argsave(tracer_args, f): n = f.__name__ @wraps(f) def wrapper(obj, *args, **kwargs): - tracer_args.append((obj, n, args, kwargs)) - return f(obj, *args, **kwargs) + out = f(obj, *args, **kwargs) + tracer_args.append((obj, n, args, kwargs, out)) + return out return wrapper def _overwrite_methods(env): @@ -358,6 +382,8 @@ def tabulate( _variable_types: set[type] = { nnx.RngState # type: ignore[misc] if isinstance(leaf, nnx.RngState) + else leaf.var_type + if isinstance(leaf, variablelib.Variable) else type(leaf) for _, leaf in nnx.to_flat_state(nnx.state(obj)) } @@ -368,38 +394,32 @@ def tabulate( env = _create_obj_env(object_types) # Modify all the object's methods to save their Tracer arguments. - # tracer_args contains (object, name, args, kwargs) tuples. - tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any]]] = [] - saver_env = {k: _argsave(tracer_args, v) for k,v in env.items()} - _overwrite_methods(saver_env) - + # tracer_args contains (object, name, args, kwargs, out) tuples. # Add JIT calculation to each method. We can extract flops and output info from # the lowered JITs. We'll only call these jitted values, which guarantees # that each method will only be traced (and added to the table) once. - jits = {} # Maps (class, method_name) to jit - for key, value in saver_env.items(): - jits[key] = nnx.jit(value) + tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any], tp.Any]] = [] + jits = {k: nnx.jit(_argsave(tracer_args, v)) for k, v in env.items()} _overwrite_methods(jits) # Trace the top function (which indirectly traces all the others) jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs) # Get call_info - rows : list[CallInfo] = [_get_call_info( - jits[(type(object), name)], name, node_stats, object, - compute_flops, *args, **kwargs) - for (object, name, args, kwargs) in tracer_args] - - # Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp` - # can result in tracing the jitted functions a second time if there's shared structure. - # This would add items to `tracer_args`, resulting in duplicate rows in the table. - if compute_vjp_flops: - for i, row in enumerate(rows): - object, method_name, args, kwargs = tracer_args[i] - def do_vjp(*args, **kwargs): - primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs) - return f_vjp(primals) - row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs)) + rows: list[CallInfo] = [ + _get_call_info( + jits[(type(object), name)], + name, + node_stats, + object, + compute_flops, + compute_vjp_flops, + args, + kwargs, + out, + ) + for (object, name, args, kwargs, out) in list(tracer_args) + ] # Restore the object's original methods _overwrite_methods(env)