Skip to content

Commit 0ed0fb7

Browse files
marksandler2Google-ML-Automation
authored andcommitted
Adds a debugging message to assert, otherwise the error is pretty cryptic.
PiperOrigin-RevId: 747657234
1 parent 4fa3cd9 commit 0ed0fb7

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

jax/_src/sharding_impls.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,11 @@ def parse_flatten_op_sharding(
898898
while dim_size > 1:
899899
axis = next(mesh_axis)
900900
axis_size = mesh_shape[axis]
901-
assert dim_size % axis_size == 0
901+
if dim_size % axis_size != 0:
902+
raise ValueError(
903+
f'{shape=} is incompatible with {mesh_shape=}: '
904+
f'{dim_size=} is not divisible by {axis_size=}.'
905+
)
902906
dim_size //= axis_size
903907
dim_partitions.append(axis)
904908
partitions.append(tuple(dim_partitions))

0 commit comments

Comments
 (0)