Skip to content

Commit b697a86

Browse files
committed
[JAX] inspect_array: thread probe name through FFI to per-probe filenames
Previously the FFI hardcoded the output path to `my_tensor_gpu{N}.bin` and ignored the `name` argument on the Python side, so every probe call in a program overwrote the same files; the only on-disk dumps that survived were those of whichever probe happened to fire last on each rank. That made multi-probe debugging (e.g. wiring TE_MOE_INSPECT through several fwd and bwd steps of an MoE block) impossible to do offline -- only the live printf log could be correlated, and only by shape/dtype.

Pass `name` through as an XLA FFI string attribute. On the C++ side it gets sanitised to a POSIX-safe filename component (characters in `[A-Za-z0-9._-]` are preserved, everything else is mapped to `_`) and used as a suffix:

    my_tensor_gpu{device}_{sanitized_name}.bin
    my_tensor_gpu{device}_{sanitized_name}_meta.json

The unsanitised name is echoed verbatim in the JSON metadata and in the printed log line so probe identity survives the rename.

On the Python side `name` is carried as a custom_vjp nondiff arg, threaded into the InspectPrimitive bind as a static kwarg, and surfaced through abstract / lowering / impl / partition / shardy_sharding_rule.
1 parent adc3227 commit b697a86

2 files changed

Lines changed: 98 additions & 40 deletions

File tree

transformer_engine/jax/csrc/extensions/inspect.cpp

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,43 @@
55
************************************************************************/
66
#include <cuda_runtime.h>
77

8+
#include <algorithm>
89
#include <fstream>
910
#include <iostream>
11+
#include <string>
12+
#include <string_view>
1013

1114
#include "../extensions.h"
1215
#include "xla/ffi/api/c_api.h"
1316

1417
namespace transformer_engine {
1518
namespace jax {
1619

20+
// Sanitize a probe name for use as a filename component: replace any
21+
// character that's not [A-Za-z0-9._-] with '_'. Probe names like
22+
// "fwd/sparse_probs_after_fused_topk" therefore become legal POSIX
23+
// filenames ("fwd_sparse_probs_after_fused_topk") without losing the
24+
// trailing semantic suffix.
25+
static std::string SanitizeProbeName(std::string_view name) {
26+
std::string out;
27+
out.reserve(name.size());
28+
for (char c : name) {
29+
if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '.' ||
30+
c == '_' || c == '-') {
31+
out.push_back(c);
32+
} else {
33+
out.push_back('_');
34+
}
35+
}
36+
if (out.empty()) {
37+
out = "anon";
38+
}
39+
return out;
40+
}
41+
1742
Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
1843
Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
19-
Result_Type output_buf) {
44+
Result_Type output_buf, std::string_view name) {
2045
NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
2146
NVTE_CHECK(output_buf->untyped_data() != nullptr,
2247
"Output must be provided for inspect operation");
@@ -42,18 +67,25 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type mi
4267
int device;
4368
NVTE_CHECK_CUDA(cudaGetDevice(&device));
4469

45-
// Write the tensor data to a file as a binary blob
46-
std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
70+
// Per-probe filenames: my_tensor_gpu{device}_{sanitized_name}.bin /
71+
// ..._meta.json. With distinct names, the on-disk dumps survive across
72+
// probes instead of being overwritten on every call, so a single test
73+
// run produces one .bin per probe per rank ready for offline analysis.
74+
std::string safe_name = SanitizeProbeName(name);
75+
std::string device_str = std::to_string(device);
76+
std::string filename = "my_tensor_gpu" + device_str + "_" + safe_name + ".bin";
4777
std::ofstream file(filename, std::ios::binary);
4878
NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
4979
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
5080
file.close();
5181

52-
// Write out a metadata file
53-
std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
82+
std::string meta_filename = "my_tensor_gpu" + device_str + "_" + safe_name + "_meta.json";
5483
std::ofstream meta_file(meta_filename);
5584
NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
5685
meta_file << "{";
86+
// Echo the original (un-sanitized) probe name so analysis tools can
87+
// recover the semantic label even when the filename had to mangle it.
88+
meta_file << "\"name\": \"" << name << "\", ";
5789
meta_file << "\"shape\": [";
5890
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
5991
meta_file << input_buf.dimensions()[i];
@@ -70,8 +102,11 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type mi
70102
meta_file << "}";
71103
meta_file.close();
72104

73-
// Log the tensor metadata to the console
74-
printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
105+
// Surface the probe name in the live log alongside the file path, so
106+
// analysing a multi-probe trace doesn't require correlating by
107+
// shape/dtype guesswork.
108+
printf("[gpu%d %.*s]: written to %s (shape: [", device, static_cast<int>(name.size()),
109+
name.data(), filename.c_str());
75110
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
76111
printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
77112
if (i < input_buf.dimensions().size() - 1) {
@@ -86,13 +121,14 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type mi
86121

87122
XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
88123
FFI::Bind()
89-
.Ctx<FFI_Stream_Type>() // stream
90-
.Arg<Buffer_Type>() // input
91-
.Arg<Buffer_Type>() // min
92-
.Arg<Buffer_Type>() // max
93-
.Arg<Buffer_Type>() // mean
94-
.Arg<Buffer_Type>() // std
95-
.Ret<Buffer_Type>() // output
124+
.Ctx<FFI_Stream_Type>() // stream
125+
.Arg<Buffer_Type>() // input
126+
.Arg<Buffer_Type>() // min
127+
.Arg<Buffer_Type>() // max
128+
.Arg<Buffer_Type>() // mean
129+
.Arg<Buffer_Type>() // std
130+
.Ret<Buffer_Type>() // output
131+
.Attr<std::string_view>("name") // probe name
96132
);
97133

98134
} // namespace jax

transformer_engine/jax/debug/experimental/inspect.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ def abstract(
3333
x_max_aval,
3434
x_mean_aval,
3535
x_std_aval,
36+
*,
37+
name,
3638
):
3739
"""
3840
inspect abstract
3941
"""
42+
del name
4043
assert (
4144
x_min_aval.shape == () and x_min_aval.dtype == jnp.float32
4245
), "x_min must be a scalar with dtype float32"
@@ -59,6 +62,8 @@ def lowering(
5962
x_max,
6063
x_mean,
6164
x_std,
65+
*,
66+
name,
6267
):
6368
"""
6469
inspect lowering rules
@@ -74,6 +79,7 @@ def lowering(
7479
x_max,
7580
x_mean,
7681
x_std,
82+
name=name,
7783
)
7884

7985
@staticmethod
@@ -83,6 +89,8 @@ def impl(
8389
x_max,
8490
x_mean,
8591
x_std,
92+
*,
93+
name,
8694
):
8795
"""
8896
inspect implementation
@@ -94,11 +102,12 @@ def impl(
94102
x_max,
95103
x_mean,
96104
x_std,
105+
name=name,
97106
)
98107
return x
99108

100109
@staticmethod
101-
def partition(mesh, arg_infos, result_infos):
110+
def partition(mesh, arg_infos, result_infos, *, name):
102111
"""
103112
Identity in sharding: the output carries the same sharding as ``x``;
104113
the four scalar stats (x_min, x_max, x_mean, x_std) are fully
@@ -119,25 +128,26 @@ def partition(mesh, arg_infos, result_infos):
119128
out_sharding = x_sharding
120129

121130
def sharded_impl(x, x_min, x_max, x_mean, x_std):
122-
return InspectPrimitive.impl(x, x_min, x_max, x_mean, x_std)
131+
return InspectPrimitive.impl(x, x_min, x_max, x_mean, x_std, name=name)
123132

124133
return mesh, sharded_impl, out_sharding, arg_shardings
125134

126135
@staticmethod
127-
def shardy_sharding_rule(*args):
136+
def shardy_sharding_rule(*args, **kwargs):
128137
"""
129138
Five operands, one output. ``x`` and the output carry the same
130139
wildcard rank; the four scalar stats are rank-0 (empty operand
131-
entries between commas).
140+
entries between commas). The ``name`` keyword attribute does not
141+
participate in the rule.
132142
"""
133-
del args
143+
del args, kwargs
134144
return "..., , , , -> ..."
135145

136146

137147
register_primitive(InspectPrimitive)
138148

139149

140-
def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
150+
def _inspect_array_inner(x: jnp.ndarray, name: str) -> jnp.ndarray:
141151
assert InspectPrimitive.outer_primitive is not None, (
142152
"InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built"
143153
" and registered."
@@ -148,50 +158,62 @@ def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
148158
jnp.max(x).astype(jnp.float32),
149159
jnp.mean(x.astype(jnp.float32)),
150160
jnp.std(x.astype(jnp.float32)),
161+
name=name,
151162
)
152163

153164

154-
@partial(jax.custom_vjp, nondiff_argnums=())
155-
def _inspect(
156-
x,
157-
):
165+
# ``name`` is a Python string and must not be traced through jax — it is
166+
# carried as a custom_vjp nondiff argument so it stays static at compile
167+
# time, threads into the primitive bind as a kwarg, and lands on the
168+
# FFI as a string attribute.
169+
@partial(jax.custom_vjp, nondiff_argnums=(1,))
170+
def _inspect(x, name):
158171
""" """
159-
output, _ = _inspect_fwd_rule(
160-
x,
161-
)
172+
output, _ = _inspect_fwd_rule(x, name)
162173
return output
163174

164175

165-
def _inspect_fwd_rule(
166-
x,
167-
):
176+
def _inspect_fwd_rule(x, name):
168177
""""""
169178
ctx = ()
170-
x = _inspect_array_inner(x)
179+
x = _inspect_array_inner(x, name)
171180
return x, ctx
172181

173182

174-
def _inspect_bwd_rule(
175-
ctx,
176-
grad,
177-
):
183+
def _inspect_bwd_rule(name, ctx, grad):
178184
""""""
179-
del ctx
185+
del name, ctx
180186
return (grad,)
181187

182188

183189
_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)
184190

185191

186192
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
187-
"""Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.
193+
"""Inspect a JAX array by dumping its data and stats to disk per-rank.
194+
195+
On every call the FFI synchronises the input device buffer to host
196+
and writes two files per rank, **keyed by ``name``** so multiple
197+
probes in the same program produce distinct dumps:
198+
199+
* ``my_tensor_gpu{device}_{sanitized_name}.bin`` – raw bytes.
200+
* ``my_tensor_gpu{device}_{sanitized_name}_meta.json`` – ``name``,
201+
shape, dtype, and min/max/mean/std summary stats.
202+
203+
A line is also printed to stdout including the probe ``name`` so
204+
multi-probe traces are easy to follow in a live log.
205+
206+
``name`` is treated as a static (non-traced) attribute, so the same
207+
probe name must be passed in every (re-)trace of an enclosing
208+
``jax.jit``; characters outside ``[A-Za-z0-9._-]`` are mapped to
209+
``_`` when forming the filename, but the unsanitised name is echoed
210+
verbatim in the JSON metadata and the printed log line.
188211
189212
Args:
190213
x (jnp.ndarray): The JAX array to inspect.
191-
name (str): The name of the array for identification in the output.
214+
name (str): Identifier for this probe; used in filenames and logs.
192215
"""
193-
del name # Name is currently unused, but can be included in the future for more informative output
194-
return _inspect(x)
216+
return _inspect(x, name)
195217

196218

197219
def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:

0 commit comments

Comments
 (0)