Skip to content

Commit 28c31b8

Browse files
toli-yGoogle-ML-Automation
authored andcommitted
Update JAX test to not rely on ToString and instead check the Device Assignment values.
PiperOrigin-RevId: 769417940
1 parent b22be86 commit 28c31b8

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

tests/xla_bridge_test.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from absl import logging
1919
from absl.testing import absltest
20-
2120
from jax import version
2221
from jax._src import compiler
2322
from jax._src import config
@@ -36,18 +35,14 @@ class XlaBridgeTest(jtu.JaxTestCase):
3635
def test_set_device_assignment_no_partition(self):
3736
compile_options = compiler.get_compile_options(
3837
num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
39-
expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
40-
"0 1 2 3 \n")
41-
self.assertEqual(compile_options.device_assignment.__repr__(),
42-
expected_device_assignment)
38+
self.assertEqual(compile_options.device_assignment.replica_count(), 4)
39+
self.assertEqual(compile_options.device_assignment.computation_count(), 1)
4340

4441
def test_set_device_assignment_with_partition(self):
4542
compile_options = compiler.get_compile_options(
4643
num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
47-
expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
48-
"0 2 \nComputation 1: 1 3 \n")
49-
self.assertEqual(compile_options.device_assignment.__repr__(),
50-
expected_device_assignment)
44+
self.assertEqual(compile_options.device_assignment.replica_count(), 2)
45+
self.assertEqual(compile_options.device_assignment.computation_count(), 2)
5146

5247
def test_set_fdo_profile(self):
5348
compile_options = compiler.get_compile_options(

0 commit comments

Comments
 (0)