17
17
18
18
from absl import logging
19
19
from absl .testing import absltest
20
-
21
20
from jax import version
22
21
from jax ._src import compiler
23
22
from jax ._src import config
@@ -36,18 +35,14 @@ class XlaBridgeTest(jtu.JaxTestCase):
36
35
def test_set_device_assignment_no_partition (self ):
37
36
compile_options = compiler .get_compile_options (
38
37
num_replicas = 4 , num_partitions = 1 , device_assignment = [0 , 1 , 2 , 3 ])
39
- expected_device_assignment = ("Computations: 1 Replicas: 4\n Computation 0: "
40
- "0 1 2 3 \n " )
41
- self .assertEqual (compile_options .device_assignment .__repr__ (),
42
- expected_device_assignment )
38
+ self .assertEqual (compile_options .device_assignment .replica_count (), 4 )
39
+ self .assertEqual (compile_options .device_assignment .computation_count (), 1 )
43
40
44
41
def test_set_device_assignment_with_partition (self ):
45
42
compile_options = compiler .get_compile_options (
46
43
num_replicas = 2 , num_partitions = 2 , device_assignment = [[0 , 1 ], [2 , 3 ]])
47
- expected_device_assignment = ("Computations: 2 Replicas: 2\n Computation 0: "
48
- "0 2 \n Computation 1: 1 3 \n " )
49
- self .assertEqual (compile_options .device_assignment .__repr__ (),
50
- expected_device_assignment )
44
+ self .assertEqual (compile_options .device_assignment .replica_count (), 2 )
45
+ self .assertEqual (compile_options .device_assignment .computation_count (), 2 )
51
46
52
47
def test_set_fdo_profile (self ):
53
48
compile_options = compiler .get_compile_options (
0 commit comments