Commit 68e5a47

[JAX] inspect_array: make name positional so custom_partitioning can resolve it
The prior commit that threaded a probe `name` through the InspectPrimitive declared it as a keyword-only argument (`*, name`) on `impl` / `partition` and left `impl_static_args = ()`. With JAX's `custom_partitioning`, that breaks at trace time:

    TypeError: keyword arguments could not be resolved to positions

`register_primitive` wraps `cls.impl` as `custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)`, and the wrapper's `__call__` runs `_resolve_kwargs(self.fun, args, kwargs)` to push bind-time kwargs back into positional slots. A keyword-only parameter has no positional slot to resolve into, so any `outer_primitive.bind(x, ..., name=name)` call from inside a `jax.jit` aborts before the FFI ever runs.

Follow the established TE pattern (e.g. `ActLuPrimitive`, `FusedMoEAuxLossBwdPrimitive`):

* Set `impl_static_args = (5,)` to declare position of `name`.
* Drop the `*` separator on `impl` so `name` is positional-or-keyword; `_resolve_kwargs` can now push the bind kwarg to position 5.
* Move static args to the head of `partition(name, mesh, ...)` and `shardy_sharding_rule(name, mesh, ...)` per `custom_partitioning`'s convention when `static_argnums` is set on the wrapped impl.
* `abstract` and `lowering` keep `*, name` because they are called by JAX directly with bind kwargs and never go through `_resolve_kwargs`.

The user-facing API (`inspect_array(x, name)`) and the `outer_primitive.bind(x, ..., name=name)` call site are unchanged.
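The failure mode described in the commit message can be reproduced without JAX. The sketch below is a simplified stand-in for the kwarg-resolution step, built on the standard library's `inspect.signature(...).bind`; the helper `resolve_kwargs_sketch` and the two `impl_*` functions are hypothetical names for illustration, not the actual JAX or TE code.

```python
import inspect


def resolve_kwargs_sketch(fun, args, kwargs):
    """Hypothetical stand-in: push kwargs into positional slots, or fail."""
    bound = inspect.signature(fun).bind(*args, **kwargs)
    bound.apply_defaults()
    if bound.kwargs:
        # Keyword-only parameters remain in ``bound.kwargs`` because they
        # have no positional slot to move into.
        raise TypeError("keyword arguments could not be resolved to positions")
    return bound.args


def impl_keyword_only(x, x_min, x_max, x_mean, x_std, *, name):
    """Old shape of ``impl``: ``name`` is keyword-only."""
    return x


def impl_positional(x, x_min, x_max, x_mean, x_std, name):
    """New shape of ``impl``: ``name`` is positional-or-keyword (index 5)."""
    return x


operands = (1.0, 0.0, 2.0, 1.0, 0.5)

# The positional-or-keyword variant resolves ``name`` to index 5.
resolved = resolve_kwargs_sketch(impl_positional, operands, {"name": "probe"})

# The keyword-only variant cannot be resolved and raises TypeError.
try:
    resolve_kwargs_sketch(impl_keyword_only, operands, {"name": "probe"})
    keyword_only_error = None
except TypeError as exc:
    keyword_only_error = str(exc)
```

With the positional signature, `resolved[5]` holds the probe name, which is the slot that `impl_static_args = (5,)` refers to.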
1 parent b697a86 commit 68e5a47

1 file changed

Lines changed: 21 additions & 10 deletions

transformer_engine/jax/debug/experimental/inspect.py

@@ -22,7 +22,13 @@ class InspectPrimitive(BasePrimitive):
 
     name = "te_inspect_ffi"
     multiple_results = False
-    impl_static_args = ()
+    # ``name`` is at positional index 5 in ``impl``; declaring it static
+    # here lets ``custom_partitioning`` resolve ``bind(..., name=...)``
+    # kwargs back to that position via ``_resolve_kwargs``. Keyword-only
+    # parameters (``*, name``) would raise
+    # TypeError: keyword arguments could not be resolved to positions
+    # at trace time, so ``impl`` accepts ``name`` positionally below.
+    impl_static_args = (5,)
     inner_primitive = None
     outer_primitive = None
 
@@ -89,14 +95,15 @@ def impl(
         x_max,
         x_mean,
         x_std,
-        *,
         name,
     ):
         """
-        inspect implementation
+        inspect implementation. ``name`` is positional (not keyword-only)
+        so ``custom_partitioning(static_argnums=(5,))`` can resolve
+        ``bind(..., name=...)`` kwargs to position 5.
         """
         assert InspectPrimitive.inner_primitive is not None
-        (x) = InspectPrimitive.inner_primitive.bind(
+        x = InspectPrimitive.inner_primitive.bind(
             x,
             x_min,
             x_max,
@@ -107,13 +114,16 @@ def impl(
         return x
 
     @staticmethod
-    def partition(mesh, arg_infos, result_infos, *, name):
+    def partition(name, mesh, arg_infos, result_infos):
         """
         Identity in sharding: the output carries the same sharding as ``x``;
         the four scalar stats (x_min, x_max, x_mean, x_std) are fully
         replicated. Without this override the primitive falls back to
         ``BasePrimitive``'s abstract partition and any multi-device JIT
         rejects the call.
+
+        Static args precede ``mesh`` per the ``custom_partitioning``
+        convention when ``static_argnums`` is set on the wrapped impl.
         """
         del result_infos
         x_sharding = arg_infos[0].sharding
@@ -128,19 +138,20 @@ def partition(mesh, arg_infos, result_infos, *, name):
         out_sharding = x_sharding
 
         def sharded_impl(x, x_min, x_max, x_mean, x_std):
-            return InspectPrimitive.impl(x, x_min, x_max, x_mean, x_std, name=name)
+            return InspectPrimitive.impl(x, x_min, x_max, x_mean, x_std, name)
 
         return mesh, sharded_impl, out_sharding, arg_shardings
 
     @staticmethod
-    def shardy_sharding_rule(*args, **kwargs):
+    def shardy_sharding_rule(*args):
         """
         Five operands, one output. ``x`` and the output carry the same
         wildcard rank; the four scalar stats are rank-0 (empty operand
-        entries between commas). The ``name`` keyword attribute does not
-        participate in the rule.
+        entries between commas). ``name`` is a static arg (precedes
+        ``mesh``/``value_types``/``result_types`` in ``args``) and does
+        not participate in the rule.
         """
-        del args, kwargs
+        del args
         return "..., , , , -> ..."
 
 
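The calling convention that the diff relies on, with static args placed ahead of `mesh` in `partition`, can be illustrated with a small pure-Python sketch. It assumes only what the commit message states about `static_argnums`; `split_static_args` and the toy `partition` callback below are hypothetical and do not reproduce JAX internals.

```python
def split_static_args(args, static_argnums):
    """Separate static values from array operands by position."""
    static = tuple(args[i] for i in static_argnums)
    dynamic = tuple(a for i, a in enumerate(args) if i not in static_argnums)
    return static, dynamic


def partition(name, mesh, arg_infos, result_infos):
    """Toy callback: the static arg ``name`` comes first, then ``mesh``."""
    del result_infos
    return {"name": name, "mesh": mesh, "num_operands": len(arg_infos)}


# Six positional values as seen by ``impl``; index 5 is the static ``name``.
call_args = ("x", "x_min", "x_max", "x_mean", "x_std", "probe")
static_args, operands = split_static_args(call_args, static_argnums=(5,))

# Static args are passed ahead of the mesh and the per-operand infos.
result = partition(*static_args, "mesh", operands, None)
```

Only the five array operands reach `arg_infos`, which matches the five-operand `sharded_impl` in the diff, while `name` is captured from the enclosing `partition` call.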