We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4fa3cd9 commit 0ed0fb7Copy full SHA for 0ed0fb7
jax/_src/sharding_impls.py
@@ -898,7 +898,11 @@ def parse_flatten_op_sharding(
898
while dim_size > 1:
899
axis = next(mesh_axis)
900
axis_size = mesh_shape[axis]
901
- assert dim_size % axis_size == 0
+ if dim_size % axis_size != 0:
902
+ raise ValueError(
903
+ f'{shape=} is incompatible with {mesh_shape=}: '
904
+ f'{dim_size=} is not divisible by {axis_size=}.'
905
+ )
906
dim_size //= axis_size
907
dim_partitions.append(axis)
908
partitions.append(tuple(dim_partitions))
0 commit comments