Skip to content

Commit d8317b5

Browse files
[Pallas][Easy] Terser printing of GridMapping unless debug is set.
PiperOrigin-RevId: 768932310
1 parent d69086d commit d8317b5

File tree

3 files changed

+92
-21
lines changed

3 files changed

+92
-21
lines changed

jax/_src/pallas/core.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ def to_block_mapping(
499499
index_map_tree: tree_util.PyTreeDef,
500500
grid: GridMappingGrid,
501501
mapped_dims: tuple[int, ...],
502+
debug: bool = False,
502503
) -> BlockMapping:
503504
if self.index_map is None:
504505
index_map_func = default_index_map(len(array_aval.shape))
@@ -539,11 +540,15 @@ def to_block_mapping(
539540

540541
fake_index_map_args, fake_index_map_kwargs = \
541542
index_map_tree.unflatten([False] * index_map_tree.num_leaves)
542-
debug = api_util.debug_info("pallas_call index_map",
543-
index_map_func, fake_index_map_args,
544-
fake_index_map_kwargs)
543+
debug_info = api_util.debug_info(
544+
"pallas_call index_map",
545+
index_map_func,
546+
fake_index_map_args,
547+
fake_index_map_kwargs,
548+
)
545549
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
546-
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
550+
lu.wrap_init(index_map_func, debug_info=debug_info), index_map_tree
551+
)
547552
with tracing_grid_env(grid, mapped_dims):
548553
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
549554
flat_index_map_fun, index_map_avals
@@ -553,7 +558,7 @@ def to_block_mapping(
553558

554559
if len(unflat_avals) != len(block_shape):
555560
raise ValueError(
556-
f"Index map function {debug.func_src_info} for "
561+
f"Index map function {debug_info.func_src_info} for "
557562
f"{origin} must return "
558563
f"{len(block_shape)} values to match {block_shape=}. "
559564
f"Currently returning {len(unflat_avals)} values:"
@@ -581,14 +586,14 @@ def to_block_mapping(
581586
for i, ov in enumerate(out_avals):
582587
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
583588
raise ValueError(
584-
f"Index map function {debug.func_src_info} for "
589+
f"Index map function {debug_info.func_src_info} for "
585590
f"{origin} must return integer scalars. Output[{i}] has type "
586591
f"{ov}."
587592
)
588593

589594
if consts:
590595
raise ValueError(
591-
f"Index map function {debug.func_src_info} for "
596+
f"Index map function {debug_info.func_src_info} for "
592597
f"{origin} must not capture constants: {consts}"
593598
)
594599

@@ -604,6 +609,7 @@ def to_block_mapping(
604609
),
605610
origin=origin,
606611
pipeline_mode=self.pipeline_mode,
612+
debug=debug,
607613
)
608614
mapping.check_invariants()
609615
return mapping
@@ -645,6 +651,7 @@ class BlockMapping:
645651
origin: OriginStr
646652
transforms: Sequence[MemoryRefTransform] = ()
647653
pipeline_mode: Buffered | None = None
654+
debug: bool = False
648655

649656
def check_invariants(self) -> None:
650657
if not config.enable_checks.value: return
@@ -716,6 +723,24 @@ def has_trivial_window(self):
716723
return False
717724
return True
718725

726+
def __repr__(self):
727+
if self.debug:
728+
return (
729+
f"BlockMapping(block_shape={self.block_shape}, "
730+
f"transformed_block_aval={self.transformed_block_aval}, "
731+
f"index_map_jaxpr={self.index_map_jaxpr}, "
732+
f"index_map_out_tree={self.index_map_out_tree}, "
733+
f"array_shape_dtype={self.array_shape_dtype}, "
734+
f"origin={self.origin}, "
735+
f"transforms={self.transforms}, "
736+
f"pipeline_mode={self.pipeline_mode}, "
737+
f"debug={self.debug})"
738+
)
739+
return f"BlockMapping(block_shape={self.block_shape})"
740+
741+
def __str__(self):
742+
return self.__repr__()
743+
719744

720745
@contextlib.contextmanager
721746
def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
@@ -780,6 +805,8 @@ class GridMapping:
780805
num_scratch_operands: int
781806
get_grid_indices: Callable | None = None
782807
local_grid_env: Callable | None = None
808+
# Primarily dictates how much debugging information is printed.
809+
debug: bool = False
783810

784811
def check_invariants(self) -> None:
785812
if not config.enable_checks.value: return
@@ -903,6 +930,29 @@ def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
903930
return tuple(
904931
bm.array_shape_dtype for bm in self.block_mappings_output)
905932

933+
def __repr__(self):
934+
if self.debug:
935+
return (
936+
f"GridMapping(grid={self.grid}, grid_names={self.grid_names}, "
937+
f"block_mappings={self.block_mappings}, "
938+
f"index_map_tree={self.index_map_tree}, "
939+
f"index_map_avals={self.index_map_avals}, "
940+
f"vmapped_dims={self.vmapped_dims}, "
941+
f"num_index_operands={self.num_index_operands}, "
942+
f"num_inputs={self.num_inputs}, "
943+
f"num_outputs={self.num_outputs}, "
944+
f"num_scratch_operands={self.num_scratch_operands}, "
945+
f"get_grid_indices={self.get_grid_indices}, "
946+
f"local_grid_env={self.local_grid_env}, "
947+
f"debug={self.debug})"
948+
)
949+
return (
950+
f"GridMapping(grid={self.grid}, block_mappings={self.block_mappings})"
951+
)
952+
953+
def __str__(self):
954+
return self.__repr__()
955+
906956

907957
def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
908958
if isinstance(dim, jax.Array):
@@ -938,6 +988,7 @@ def _convert_block_spec_to_block_mapping(
938988
index_map_tree: tree_util.PyTreeDef,
939989
grid: GridMappingGrid,
940990
mapped_dims: tuple[int, ...],
991+
debug: bool = False,
941992
) -> BlockMapping:
942993
if block_spec is no_block_spec:
943994
block_spec = BlockSpec(None, None)
@@ -948,8 +999,10 @@ def _convert_block_spec_to_block_mapping(
948999
index_map_tree=index_map_tree,
9491000
grid=grid,
9501001
mapped_dims=mapped_dims,
1002+
debug=debug,
9511003
)
9521004

1005+
9531006
index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)
9541007

9551008

@@ -1023,8 +1076,8 @@ def get_grid_mapping(
10231076
out_avals: Sequence[jax_core.AbstractValue],
10241077
out_tree: tree_util.PyTreeDef,
10251078
out_origins: Sequence[OriginStr],
1026-
) -> tuple[tuple[jax_core.AbstractValue, ...],
1027-
GridMapping]:
1079+
debug: bool = False,
1080+
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
10281081
if dynamic_shapes_export_enabled():
10291082
dim_check : Any = jax_core.is_dim
10301083
else:
@@ -1090,6 +1143,7 @@ def get_grid_mapping(
10901143
index_map_tree=index_map_tree,
10911144
grid=grid_mapping_grid, # type: ignore[arg-type]
10921145
mapped_dims=(),
1146+
debug=debug,
10931147
),
10941148
flat_in_specs,
10951149
in_origins[num_flat_scalar_prefetch:],
@@ -1112,6 +1166,7 @@ def get_grid_mapping(
11121166
index_map_tree=index_map_tree,
11131167
grid=grid_mapping_grid, # type: ignore[arg-type]
11141168
mapped_dims=(),
1169+
debug=debug,
11151170
),
11161171
flat_out_specs,
11171172
out_origins,
@@ -1128,6 +1183,7 @@ def get_grid_mapping(
11281183
num_inputs=len(flat_in_specs),
11291184
num_outputs=len(flat_out_specs),
11301185
num_scratch_operands=num_flat_scratch_operands,
1186+
debug=debug,
11311187
)
11321188
grid_mapping.check_invariants()
11331189
in_ref_avals = [bm.ref_aval for bm in in_block_mappings]

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ def to_block_mapping(
808808
index_map_tree: tree_util.PyTreeDef,
809809
grid: pallas_core.GridMappingGrid,
810810
mapped_dims: tuple[int, ...],
811+
debug: bool = False,
811812
) -> pallas_core.BlockMapping:
812813
bm = super().to_block_mapping(
813814
origin,
@@ -816,6 +817,7 @@ def to_block_mapping(
816817
index_map_tree=index_map_tree,
817818
grid=grid,
818819
mapped_dims=mapped_dims,
820+
debug=debug,
819821
)
820822
block_inner_aval = bm.block_aval.inner_aval
821823
for t in self.transforms:

jax/_src/pallas/pallas_call.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,10 +1131,10 @@ def _ensure_2d_error_shape(arg):
11311131
retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
11321132
*error_memref_aval, *output_aval, *scratch_aval]
11331133
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
1134-
debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
1134+
debug_info = api_util.debug_info("checkify_pallas", checked_kernel_fn,
11351135
retrace_in_avals, {})
11361136
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
1137-
lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree)
1137+
lu.wrap_init(checked_kernel_fn, debug_info=debug_info), jaxpr_in_tree)
11381138

11391139
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
11401140
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
@@ -1146,13 +1146,18 @@ def _ensure_2d_error_shape(arg):
11461146
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
11471147
error_origins = tuple(f"errors[{tree_util.keystr(p)}" for p in error_paths)
11481148
error_block_mappings = map(
1149-
partial(
1150-
pallas_core._convert_block_spec_to_block_mapping,
1151-
index_map_avals=grid_mapping.index_map_avals,
1152-
index_map_tree=grid_mapping.index_map_tree,
1153-
grid=grid_mapping.grid,
1154-
mapped_dims=grid_mapping.vmapped_dims),
1155-
error_block_specs, error_origins, shaped_err_avals)
1149+
partial(
1150+
pallas_core._convert_block_spec_to_block_mapping,
1151+
index_map_avals=grid_mapping.index_map_avals,
1152+
index_map_tree=grid_mapping.index_map_tree,
1153+
grid=grid_mapping.grid,
1154+
mapped_dims=grid_mapping.vmapped_dims,
1155+
debug=True,
1156+
),
1157+
error_block_specs,
1158+
error_origins,
1159+
shaped_err_avals,
1160+
)
11561161
input_block_mappings, output_block_mappings = split_list(
11571162
grid_mapping.block_mappings, [num_kernel_inputs,])
11581163
grid_mapping_with_error = grid_mapping.replace(
@@ -1396,7 +1401,9 @@ def _pallas_call_state_discharge_rule(
13961401
index_map_tree=grid_mapping.index_map_tree,
13971402
grid=grid_mapping.grid,
13981403
mapped_dims=grid_mapping.mapped_dims,
1399-
) for ref_aval, block_spec in zip(ref_avals, ref_block_specs)
1404+
debug=debug,
1405+
)
1406+
for ref_aval, block_spec in zip(ref_avals, ref_block_specs)
14001407
]
14011408
in_block_mappings, out_block_mappings = split_list(
14021409
grid_mapping.block_mappings, [grid_mapping.num_inputs]
@@ -1665,8 +1672,14 @@ def wrapped(*args):
16651672
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
16661673
kernel_args, grid_mapping = pallas_core.get_grid_mapping(
16671674
grid_spec,
1668-
flat_in_avals, in_tree, in_origins,
1669-
flat_out_avals, out_tree, out_origins)
1675+
flat_in_avals,
1676+
in_tree,
1677+
in_origins,
1678+
flat_out_avals,
1679+
out_tree,
1680+
out_origins,
1681+
debug,
1682+
)
16701683
flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args)
16711684
flat_kernel_avals = tuple(
16721685
x.ref if isinstance(x, state_types.TransformedRef) else x

0 commit comments

Comments
 (0)