@@ -499,6 +499,7 @@ def to_block_mapping(
499
499
index_map_tree : tree_util .PyTreeDef ,
500
500
grid : GridMappingGrid ,
501
501
mapped_dims : tuple [int , ...],
502
+ debug : bool = False ,
502
503
) -> BlockMapping :
503
504
if self .index_map is None :
504
505
index_map_func = default_index_map (len (array_aval .shape ))
@@ -539,11 +540,15 @@ def to_block_mapping(
539
540
540
541
fake_index_map_args , fake_index_map_kwargs = \
541
542
index_map_tree .unflatten ([False ] * index_map_tree .num_leaves )
542
- debug = api_util .debug_info ("pallas_call index_map" ,
543
- index_map_func , fake_index_map_args ,
544
- fake_index_map_kwargs )
543
+ debug_info = api_util .debug_info (
544
+ "pallas_call index_map" ,
545
+ index_map_func ,
546
+ fake_index_map_args ,
547
+ fake_index_map_kwargs ,
548
+ )
545
549
flat_index_map_fun , index_map_out_tree_thunk = api_util .flatten_fun (
546
- lu .wrap_init (index_map_func , debug_info = debug ), index_map_tree )
550
+ lu .wrap_init (index_map_func , debug_info = debug_info ), index_map_tree
551
+ )
547
552
with tracing_grid_env (grid , mapped_dims ):
548
553
jaxpr , out_avals , consts , () = pe .trace_to_jaxpr_dynamic (
549
554
flat_index_map_fun , index_map_avals
@@ -553,7 +558,7 @@ def to_block_mapping(
553
558
554
559
if len (unflat_avals ) != len (block_shape ):
555
560
raise ValueError (
556
- f"Index map function { debug .func_src_info } for "
561
+ f"Index map function { debug_info .func_src_info } for "
557
562
f"{ origin } must return "
558
563
f"{ len (block_shape )} values to match { block_shape = } . "
559
564
f"Currently returning { len (unflat_avals )} values:"
@@ -581,14 +586,14 @@ def to_block_mapping(
581
586
for i , ov in enumerate (out_avals ):
582
587
if ov .shape or ov .dtype not in [jnp .int32 , jnp .int64 ]:
583
588
raise ValueError (
584
- f"Index map function { debug .func_src_info } for "
589
+ f"Index map function { debug_info .func_src_info } for "
585
590
f"{ origin } must return integer scalars. Output[{ i } ] has type "
586
591
f"{ ov } ."
587
592
)
588
593
589
594
if consts :
590
595
raise ValueError (
591
- f"Index map function { debug .func_src_info } for "
596
+ f"Index map function { debug_info .func_src_info } for "
592
597
f"{ origin } must not capture constants: { consts } "
593
598
)
594
599
@@ -604,6 +609,7 @@ def to_block_mapping(
604
609
),
605
610
origin = origin ,
606
611
pipeline_mode = self .pipeline_mode ,
612
+ debug = debug ,
607
613
)
608
614
mapping .check_invariants ()
609
615
return mapping
@@ -645,6 +651,7 @@ class BlockMapping:
645
651
origin : OriginStr
646
652
transforms : Sequence [MemoryRefTransform ] = ()
647
653
pipeline_mode : Buffered | None = None
654
+ debug : bool = False
648
655
649
656
def check_invariants (self ) -> None :
650
657
if not config .enable_checks .value : return
@@ -716,6 +723,24 @@ def has_trivial_window(self):
716
723
return False
717
724
return True
718
725
726
+ def __repr__ (self ):
727
+ if self .debug :
728
+ return (
729
+ f"BlockMapping(block_shape={ self .block_shape } , "
730
+ f"transformed_block_aval={ self .transformed_block_aval } , "
731
+ f"index_map_jaxpr={ self .index_map_jaxpr } , "
732
+ f"index_map_out_tree={ self .index_map_out_tree } , "
733
+ f"array_shape_dtype={ self .array_shape_dtype } , "
734
+ f"origin={ self .origin } , "
735
+ f"transforms={ self .transforms } , "
736
+ f"pipeline_mode={ self .pipeline_mode } , "
737
+ f"debug={ self .debug } )"
738
+ )
739
+ return f"BlockMapping(block_shape={ self .block_shape } )"
740
+
741
+ def __str__ (self ):
742
+ return self .__repr__ ()
743
+
719
744
720
745
@contextlib .contextmanager
721
746
def tracing_grid_env (grid : GridMappingGrid , mapped_dims : tuple [int , ...]):
@@ -780,6 +805,8 @@ class GridMapping:
780
805
num_scratch_operands : int
781
806
get_grid_indices : Callable | None = None
782
807
local_grid_env : Callable | None = None
808
+ # Primarily dictates how much debugging information is printed.
809
+ debug : bool = False
783
810
784
811
def check_invariants (self ) -> None :
785
812
if not config .enable_checks .value : return
@@ -903,6 +930,29 @@ def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
903
930
return tuple (
904
931
bm .array_shape_dtype for bm in self .block_mappings_output )
905
932
933
+ def __repr__ (self ):
934
+ if self .debug :
935
+ return (
936
+ f"GridMapping(grid={ self .grid } , grid_names={ self .grid_names } , "
937
+ f"block_mappings={ self .block_mappings } , "
938
+ f"index_map_tree={ self .index_map_tree } , "
939
+ f"index_map_avals={ self .index_map_avals } , "
940
+ f"vmapped_dims={ self .vmapped_dims } , "
941
+ f"num_index_operands={ self .num_index_operands } , "
942
+ f"num_inputs={ self .num_inputs } , "
943
+ f"num_outputs={ self .num_outputs } , "
944
+ f"num_scratch_operands={ self .num_scratch_operands } , "
945
+ f"get_grid_indices={ self .get_grid_indices } , "
946
+ f"local_grid_env={ self .local_grid_env } , "
947
+ f"debug={ self .debug } )"
948
+ )
949
+ return (
950
+ f"GridMapping(grid={ self .grid } , block_mappings={ self .block_mappings } )"
951
+ )
952
+
953
+ def __str__ (self ):
954
+ return self .__repr__ ()
955
+
906
956
907
957
def _is_valid_grid_dim (dim : int | jax .Array ) -> bool :
908
958
if isinstance (dim , jax .Array ):
@@ -938,6 +988,7 @@ def _convert_block_spec_to_block_mapping(
938
988
index_map_tree : tree_util .PyTreeDef ,
939
989
grid : GridMappingGrid ,
940
990
mapped_dims : tuple [int , ...],
991
+ debug : bool = False ,
941
992
) -> BlockMapping :
942
993
if block_spec is no_block_spec :
943
994
block_spec = BlockSpec (None , None )
@@ -948,8 +999,10 @@ def _convert_block_spec_to_block_mapping(
948
999
index_map_tree = index_map_tree ,
949
1000
grid = grid ,
950
1001
mapped_dims = mapped_dims ,
1002
+ debug = debug ,
951
1003
)
952
1004
1005
+
953
1006
index_map_grid_aval = jax_core .ShapedArray ((), jnp .int32 )
954
1007
955
1008
@@ -1023,8 +1076,8 @@ def get_grid_mapping(
1023
1076
out_avals : Sequence [jax_core .AbstractValue ],
1024
1077
out_tree : tree_util .PyTreeDef ,
1025
1078
out_origins : Sequence [OriginStr ],
1026
- ) -> tuple [ tuple [ jax_core . AbstractValue , ...] ,
1027
- GridMapping ]:
1079
+ debug : bool = False ,
1080
+ ) -> tuple [ tuple [ jax_core . AbstractValue , ...], GridMapping ]:
1028
1081
if dynamic_shapes_export_enabled ():
1029
1082
dim_check : Any = jax_core .is_dim
1030
1083
else :
@@ -1090,6 +1143,7 @@ def get_grid_mapping(
1090
1143
index_map_tree = index_map_tree ,
1091
1144
grid = grid_mapping_grid , # type: ignore[arg-type]
1092
1145
mapped_dims = (),
1146
+ debug = debug ,
1093
1147
),
1094
1148
flat_in_specs ,
1095
1149
in_origins [num_flat_scalar_prefetch :],
@@ -1112,6 +1166,7 @@ def get_grid_mapping(
1112
1166
index_map_tree = index_map_tree ,
1113
1167
grid = grid_mapping_grid , # type: ignore[arg-type]
1114
1168
mapped_dims = (),
1169
+ debug = debug ,
1115
1170
),
1116
1171
flat_out_specs ,
1117
1172
out_origins ,
@@ -1128,6 +1183,7 @@ def get_grid_mapping(
1128
1183
num_inputs = len (flat_in_specs ),
1129
1184
num_outputs = len (flat_out_specs ),
1130
1185
num_scratch_operands = num_flat_scratch_operands ,
1186
+ debug = debug ,
1131
1187
)
1132
1188
grid_mapping .check_invariants ()
1133
1189
in_ref_avals = [bm .ref_aval for bm in in_block_mappings ]
0 commit comments