diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py
index 7b318a7..c43d308 100644
--- a/drjax/_src/impls.py
+++ b/drjax/_src/impls.py
@@ -322,7 +322,7 @@ def _constrain_at_placement_with_slices_like(x):
       lambda arr: jax.sharding.reshard(
           arr,
           jax.sharding.NamedSharding(
-              mesh, P(placement, *arr.sharding.spec[1:])
+              mesh, spec=P(placement, *jax.typeof(arr).sharding.spec[1:])
           ),
       ),
       result,
diff --git a/drjax/_src/impls_test.py b/drjax/_src/impls_test.py
index ff91f6a..30c19b0 100644
--- a/drjax/_src/impls_test.py
+++ b/drjax/_src/impls_test.py
@@ -13,13 +13,26 @@
 # limitations under the License.
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import chex
 from drjax._src import impls
 import jax
 from jax import numpy as jnp
+from jax.sharding import AxisType, NamedSharding, PartitionSpec  # pylint: disable=g-multiple-import
+import numpy as np
 
 
-class ImplsTest(absltest.TestCase):
+def create_mesh(
+    axis_type: AxisType,
+) -> jax.sharding.Mesh:
+  return jax.sharding.Mesh(
+      np.asarray(jax.devices()).reshape(2, 4),
+      axis_names=('clients', 'data'),
+      axis_types=(axis_type, axis_type),
+  )
+
+
+class ImplsTest(parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
@@ -54,30 +67,39 @@ def temp_sens_example(m, t):
         temp_sens_example(measurements, jnp.median(measurements)), 0.5
     )
 
-  def test_runs_fake_training(self):
-    comp_factory = impls.PlacedComputations(
-        placements_to_n_elements=self._placements,
-    )
-    def _reduce_sequence(sequence, model):
-
-      for _ in range(sequence.shape[0]):
-        model += 1
-      return model
-
-    def fake_training(model, data):
-      model_at_clients = comp_factory.broadcast_to_placement(model, 'clients')
-      trained_models = comp_factory.map_to_placement(
-          _reduce_sequence, (data, model_at_clients), 'clients'
+  @parameterized.product(
+      mesh_axes_type=[AxisType.Auto, AxisType.Explicit],
+  )
+  def test_runs_grad_training(self, mesh_axes_type):
+    mesh = create_mesh(mesh_axes_type)
+    with jax.set_mesh(mesh):
+      comp_factory = impls.PlacedComputations(
+          placements_to_n_elements=self._placements,
       )
 
-      return comp_factory.mean_from_placement(trained_models)
-    clients_data = jnp.ones(
-        shape=[self._placements['clients'], self._sequence_length]
-    )
+      def update(model, x):
+        return jax.value_and_grad(lambda m, x: jnp.sum(m * jnp.square(x)))(
+            model, x
+        )
+
+      def test_training(model, data):
+        model_at_clients = comp_factory.broadcast_to_placement(model, 'clients')
+        grads, _ = comp_factory.map_to_placement(
+            update, (model_at_clients, data), 'clients'
+        )
+        return comp_factory.mean_from_placement(grads)
+
+      clients_data = jax.device_put(
+          jnp.ones(shape=(self._placements['clients'],), dtype=jnp.float32),
+          device=NamedSharding(mesh, PartitionSpec('clients')),
+      )
+      model = jax.device_put([0.0], device=NamedSharding(mesh, PartitionSpec()))
+      self.assertEqual(jax.jit(test_training)(model, clients_data), 0.0)
 
 
-    model = jnp.array(0.0)
-    self.assertEqual(fake_training(model, clients_data), 10.0)
+# This allows us to test sharding behavior across multiple devices.
+def setUpModule():
+  chex.set_n_cpu_devices(8)
 
 
 if __name__ == '__main__':
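
Note (reviewer sketch, not part of the patch): the impls.py change reads a traced array's partition spec from its type, via jax.typeof(arr).sharding.spec, instead of from arr.sharding, and the new test exercises that resharding path under a mesh whose axes are explicitly typed. Below is a minimal standalone sketch of the same pattern. It assumes a host configured for 8 CPU devices (as setUpModule arranges with chex.set_n_cpu_devices) and a JAX release providing jax.sharding.reshard, jax.typeof, jax.set_mesh, and AxisType, i.e. the APIs the patch itself uses; the helper name constrain_to_clients is hypothetical.

    import chex
    import jax
    from jax import numpy as jnp
    from jax.sharding import AxisType, NamedSharding, PartitionSpec as P
    import numpy as np

    # Must run before any JAX device is used (mirrors setUpModule in the test).
    chex.set_n_cpu_devices(8)

    # 2x4 mesh with explicitly typed axes, matching create_mesh in the test.
    mesh = jax.sharding.Mesh(
        np.asarray(jax.devices()).reshape(2, 4),
        axis_names=('clients', 'data'),
        axis_types=(AxisType.Explicit, AxisType.Explicit),
    )


    @jax.jit
    def constrain_to_clients(arr):  # hypothetical helper, for illustration only
      # Under explicit sharding, the spec of a traced value is carried on its
      # type, so it is read via jax.typeof rather than arr.sharding.
      spec = jax.typeof(arr).sharding.spec
      return jax.sharding.reshard(arr, NamedSharding(mesh, P('clients', *spec[1:])))


    with jax.set_mesh(mesh):
      x = jax.device_put(
          jnp.ones((2, 4), dtype=jnp.float32),
          NamedSharding(mesh, P(None, 'data')),
      )
      y = constrain_to_clients(x)  # resharded to P('clients', 'data')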