Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 50 additions & 30 deletions flax/nnx/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _collect_stats(
if id(value) in node_stats:
continue
elif isinstance(value, variablelib.Variable):
var_type = type(value)
var_type = value.var_type
if issubclass(var_type, nnx.RngState):
var_type = nnx.RngState
size_bytes = SizeBytes.from_any(value.get_value())
Expand Down Expand Up @@ -168,10 +168,32 @@ def inner(state, *args, **kwargs):
return f(model, *args, **kwargs)
return jax.vjp(inner, state, *args, **kwargs)

def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *args, **kwargs):
e = jitted.lower(obj, *args, **kwargs)
flops = _get_flops(e) if compute_flops else None
outputs = e.lowered.out_info[2]
def _get_call_info(
jitted,
method_name,
node_stats,
obj,
compute_flops: bool,
compute_vjp_flops: bool,
args,
kwargs,
outputs,
):
if compute_flops:
e = jitted.lower(obj, *args, **kwargs)
flops = _get_flops(e)
else:
flops = None
if compute_vjp_flops:

def do_vjp(*args, **kwargs):
primals, f_vjp = _pure_nnx_vjp(jitted, obj, *args, **kwargs)
return f_vjp(primals)

e_vjp = jax.jit(do_vjp).lower(obj, *args, **kwargs)
vjp_flops = _get_flops(e_vjp)
else:
vjp_flops = None
output_repr = jax.tree.map(_to_dummy_array, outputs)
input_args_info, input_kwargs_info = jax.tree.map(
_to_dummy_array, (args, kwargs)
Expand All @@ -191,6 +213,7 @@ def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *a
input_kwargs=input_kwargs_info,
outputs=output_repr,
flops=flops,
vjp_flops=vjp_flops,
)


Expand All @@ -211,8 +234,9 @@ def _argsave(tracer_args, f):
n = f.__name__
@wraps(f)
def wrapper(obj, *args, **kwargs):
tracer_args.append((obj, n, args, kwargs))
return f(obj, *args, **kwargs)
out = f(obj, *args, **kwargs)
tracer_args.append((obj, n, args, kwargs, out))
return out
return wrapper

def _overwrite_methods(env):
Expand Down Expand Up @@ -358,6 +382,8 @@ def tabulate(
_variable_types: set[type] = {
nnx.RngState # type: ignore[misc]
if isinstance(leaf, nnx.RngState)
else leaf.var_type
if isinstance(leaf, variablelib.Variable)
else type(leaf)
for _, leaf in nnx.to_flat_state(nnx.state(obj))
}
Expand All @@ -368,38 +394,32 @@ def tabulate(
env = _create_obj_env(object_types)

# Modify all the object's methods to save their Tracer arguments.
# tracer_args contains (object, name, args, kwargs) tuples.
tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any]]] = []
saver_env = {k: _argsave(tracer_args, v) for k,v in env.items()}
_overwrite_methods(saver_env)

# tracer_args contains (object, name, args, kwargs, out) tuples.
# Add JIT calculation to each method. We can extract flops and output info from
# the lowered JITs. We'll only call these jitted values, which guarantees
# that each method will only be traced (and added to the table) once.
jits = {} # Maps (class, method_name) to jit
for key, value in saver_env.items():
jits[key] = nnx.jit(value)
tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any], tp.Any]] = []
jits = {k: nnx.jit(_argsave(tracer_args, v)) for k, v in env.items()}
_overwrite_methods(jits)

# Trace the top function (which indirectly traces all the others)
jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)

# Get call_info
rows : list[CallInfo] = [_get_call_info(
jits[(type(object), name)], name, node_stats, object,
compute_flops, *args, **kwargs)
for (object, name, args, kwargs) in tracer_args]

# Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
# can result in tracing the jitted functions a second time if there's shared structure.
# This would add items to `tracer_args`, resulting in duplicate rows in the table.
if compute_vjp_flops:
for i, row in enumerate(rows):
object, method_name, args, kwargs = tracer_args[i]
def do_vjp(*args, **kwargs):
primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
return f_vjp(primals)
row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
rows: list[CallInfo] = [
_get_call_info(
jits[(type(object), name)],
name,
node_stats,
object,
compute_flops,
compute_vjp_flops,
args,
kwargs,
out,
)
for (object, name, args, kwargs, out) in list(tracer_args)
]

# Restore the object's original methods
_overwrite_methods(env)
Expand Down
Loading