|
21 | 21 | import threading
|
22 | 22 | import time
|
23 | 23 | import unittest
|
| 24 | +import unittest.mock |
24 | 25 | from absl.testing import absltest
|
25 | 26 | import pathlib
|
26 | 27 |
|
|
39 | 40 |
|
40 | 41 | try:
|
41 | 42 | from xprof.convert import _pywrap_profiler_plugin
|
| 43 | + import jax.collect_profile |
42 | 44 | except ImportError:
|
43 | 45 | _pywrap_profiler_plugin = None
|
44 | 46 |
|
@@ -435,5 +437,43 @@ def on_profile():
|
435 | 437 | thread_profiler.join()
|
436 | 438 | self._check_xspace_pb_exist(logdir)
|
437 | 439 |
|
| 440 | + @unittest.skipIf( |
| 441 | + not (portpicker and _pywrap_profiler_plugin), |
| 442 | + "Test requires xprof and portpicker") |
| 443 | + def test_remote_profiler_gcs_path(self): |
| 444 | + port = portpicker.pick_unused_port() |
| 445 | + jax.profiler.start_server(port) |
| 446 | + |
| 447 | + profile_done = threading.Event() |
| 448 | + logdir = "gs://mock-test-bucket/test-dir" |
| 449 | + # Mock XProf call in collect_profile. |
| 450 | + _pywrap_profiler_plugin.trace = unittest.mock.MagicMock() |
| 451 | + def on_profile(): |
| 452 | + jax.collect_profile(port, 500, logdir, no_perfetto_link=True) |
| 453 | + profile_done.set() |
| 454 | + |
| 455 | + thread_profiler = threading.Thread( |
| 456 | + target=on_profile, args=()) |
| 457 | + thread_profiler.start() |
| 458 | + start_time = time.time() |
| 459 | + y = jnp.zeros((5, 5)) |
| 460 | + while not profile_done.is_set(): |
| 461 | + # The timeout here must be relatively high. The profiler takes a while to |
| 462 | + # start up on Cloud TPUs. |
| 463 | + if time.time() - start_time > 30: |
| 464 | + raise RuntimeError("Profile did not complete in 30s") |
| 465 | + y = jnp.dot(y, y) |
| 466 | + jax.profiler.stop_server() |
| 467 | + thread_profiler.join() |
| 468 | + _pywrap_profiler_plugin.trace.assert_called_once_with( |
| 469 | + unittest.mock.ANY, |
| 470 | + logdir, |
| 471 | + unittest.mock.ANY, |
| 472 | + unittest.mock.ANY, |
| 473 | + unittest.mock.ANY, |
| 474 | + unittest.mock.ANY, |
| 475 | + unittest.mock.ANY, |
| 476 | + ) |
| 477 | + |
438 | 478 | if __name__ == "__main__":
|
439 | 479 | absltest.main(testLoader=jtu.JaxTestLoader())
|
0 commit comments