2 changes: 1 addition & 1 deletion drjax/_src/impls.py
@@ -322,7 +322,7 @@ def _constrain_at_placement_with_slices_like(x):
       lambda arr: jax.sharding.reshard(
           arr,
           jax.sharding.NamedSharding(
-              mesh, P(placement, *arr.sharding.spec[1:])
+              mesh, spec=P(placement, *jax.typeof(arr).sharding.spec[1:])
           ),
       ),
       result,
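Note on the change above: inside a jit-traced function, `arr` is a tracer, so `arr.sharding` is not a reliable way to read its spec; `jax.typeof(arr)` returns the abstract value, whose `.sharding` is populated when the mesh axes are explicit. A minimal standalone sketch of this pattern (not drjax code; the 'clients' axis name and shapes are illustrative assumptions, and it assumes a recent JAX with explicit sharding):

# Sketch only: read a sharding spec off a traced value and reshard with it.
import jax
import numpy as np
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

mesh = jax.make_mesh(
    (jax.device_count(),), ('clients',), axis_types=(AxisType.Explicit,)
)


@jax.jit
def constrain_to_clients(x):
  # Inside jit, x is a tracer, so x.sharding is not available in general;
  # jax.typeof(x) gives the abstract value, which carries the sharding spec
  # when the mesh axes are Explicit.
  spec = jax.typeof(x).sharding.spec
  return jax.sharding.reshard(x, NamedSharding(mesh, P('clients', *spec[1:])))


with jax.set_mesh(mesh):
  x = jax.device_put(
      np.ones((jax.device_count(), 4)),
      NamedSharding(mesh, P('clients', None)),
  )
  print(constrain_to_clients(x).sharding.spec)  # inspect the resulting spec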
64 changes: 43 additions & 21 deletions drjax/_src/impls_test.py
@@ -13,13 +13,26 @@
 # limitations under the License.

 from absl.testing import absltest
+from absl.testing import parameterized
+import chex
 from drjax._src import impls
 import jax
 from jax import numpy as jnp
+from jax.sharding import AxisType, NamedSharding, PartitionSpec  # pylint: disable=g-multiple-import
+import numpy as np


-class ImplsTest(absltest.TestCase):
+def create_mesh(
+    axis_type: AxisType,
+) -> jax.sharding.Mesh:
+  return jax.sharding.Mesh(
+      np.asarray(jax.devices()).reshape(2, 4),
+      axis_names=('clients', 'data'),
+      axis_types=(axis_type, axis_type),
+  )
+
+
+class ImplsTest(parameterized.TestCase):

   def setUp(self):
     super().setUp()
@@ -54,30 +67,39 @@ def temp_sens_example(m, t):
         temp_sens_example(measurements, jnp.median(measurements)), 0.5
     )

-  def test_runs_fake_training(self):
-    comp_factory = impls.PlacedComputations(
-        placements_to_n_elements=self._placements,
-    )
-    def _reduce_sequence(sequence, model):
-
-      for _ in range(sequence.shape[0]):
-        model += 1
-      return model
-
-    def fake_training(model, data):
-      model_at_clients = comp_factory.broadcast_to_placement(model, 'clients')
-      trained_models = comp_factory.map_to_placement(
-          _reduce_sequence, (data, model_at_clients), 'clients'
+  @parameterized.product(
+      mesh_axes_type=[AxisType.Auto, AxisType.Explicit],
+  )
+  def test_runs_grad_training(self, mesh_axes_type):
+    mesh = create_mesh(mesh_axes_type)
+    with jax.set_mesh(mesh):
+      comp_factory = impls.PlacedComputations(
+          placements_to_n_elements=self._placements,
       )
-      return comp_factory.mean_from_placement(trained_models)
-
-    clients_data = jnp.ones(
-        shape=[self._placements['clients'], self._sequence_length]
-    )
+      def update(model, x):
+        return jax.value_and_grad(lambda m, x: jnp.sum(m * jnp.square(x)))(
+            model, x
+        )
+
+      def test_training(model, data):
+        model_at_clients = comp_factory.broadcast_to_placement(model, 'clients')
+        grads, _ = comp_factory.map_to_placement(
+            update, (model_at_clients, data), 'clients'
+        )
+        return comp_factory.mean_from_placement(grads)
+
+      clients_data = jax.device_put(
+          jnp.ones(shape=(self._placements['clients'],), dtype=jnp.float32),
+          device=NamedSharding(mesh, PartitionSpec('clients')),
+      )
+      model = jax.device_put([0.0], device=NamedSharding(mesh, PartitionSpec()))
+      self.assertEqual(jax.jit(test_training)(model, clients_data), 0.0)

-    model = jnp.array(0.0)
-
-    self.assertEqual(fake_training(model, clients_data), 10.0)
+# This allows us to test sharding behavior across multiple devices.
+def setUpModule():
+  chex.set_n_cpu_devices(8)


 if __name__ == '__main__':
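One note on the new `setUpModule`: `chex.set_n_cpu_devices(8)` has to run before JAX initializes its backend, which is why it lives at module setup rather than inside a test. A rough standalone sketch of the same effect without chex, using the underlying XLA flag (an assumption about the equivalent setup, not part of this PR):

# Sketch: expose 8 fake CPU devices so sharding tests can run on one host.
# The flag must be set before anything initializes the JAX backend.
import os

os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '')
    + ' --xla_force_host_platform_device_count=8'
)

import jax  # imported only after the flag is in place

print(jax.device_count())  # 8 on a CPU-only host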