Commit 68e5a47
committed
[JAX] inspect_array: make
The prior commit that threaded a probe `name` through the
InspectPrimitive declared it as a keyword-only argument (`*, name`)
on `impl` / `partition` and left `impl_static_args = ()`. With
JAX's `custom_partitioning`, that breaks at trace time:
TypeError: keyword arguments could not be resolved to positions
`register_primitive` wraps `cls.impl` as
`custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)`,
and the wrapper's `__call__` runs `_resolve_kwargs(self.fun, args,
kwargs)` to push bind-time kwargs back into positional slots. A
keyword-only parameter has no positional slot to resolve into, so any
`outer_primitive.bind(x, ..., name=name)` call from inside a
`jax.jit` aborts before the FFI ever runs.
Follow the established TE pattern (e.g. `ActLuPrimitive`,
`FusedMoEAuxLossBwdPrimitive`):
* Set `impl_static_args = (5,)` to declare position of `name`.
* Drop the `*` separator on `impl` so `name` is
positional-or-keyword; `_resolve_kwargs` can now push the
bind kwarg to position 5.
* Move static args to the head of `partition(name, mesh, ...)`
and `shardy_sharding_rule(name, mesh, ...)` per
`custom_partitioning`'s convention when `static_argnums`
is set on the wrapped impl.
* `abstract` and `lowering` keep `*, name` because they are
called by JAX directly with bind kwargs and never go through
`_resolve_kwargs`.
The user-facing API (`inspect_array(x, name)`) and the
`outer_primitive.bind(x, ..., name=name)` call site are unchanged.name positional so custom_partitioning can resolve it1 parent b697a86 commit 68e5a47
1 file changed
Lines changed: 21 additions & 10 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
25 | | - | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
26 | 32 | | |
27 | 33 | | |
28 | 34 | | |
| |||
89 | 95 | | |
90 | 96 | | |
91 | 97 | | |
92 | | - | |
93 | 98 | | |
94 | 99 | | |
95 | 100 | | |
96 | | - | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
97 | 104 | | |
98 | 105 | | |
99 | | - | |
| 106 | + | |
100 | 107 | | |
101 | 108 | | |
102 | 109 | | |
| |||
107 | 114 | | |
108 | 115 | | |
109 | 116 | | |
110 | | - | |
| 117 | + | |
111 | 118 | | |
112 | 119 | | |
113 | 120 | | |
114 | 121 | | |
115 | 122 | | |
116 | 123 | | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
117 | 127 | | |
118 | 128 | | |
119 | 129 | | |
| |||
128 | 138 | | |
129 | 139 | | |
130 | 140 | | |
131 | | - | |
| 141 | + | |
132 | 142 | | |
133 | 143 | | |
134 | 144 | | |
135 | 145 | | |
136 | | - | |
| 146 | + | |
137 | 147 | | |
138 | 148 | | |
139 | 149 | | |
140 | | - | |
141 | | - | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
142 | 153 | | |
143 | | - | |
| 154 | + | |
144 | 155 | | |
145 | 156 | | |
146 | 157 | | |
| |||
0 commit comments