diff --git a/jax/collect_profile.py b/jax/collect_profile.py index 48468bb10064..7af3033d7b85 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -72,17 +72,25 @@ def collect_profile(port: int, duration_in_ms: int, host: str, "device_tracer_level": device_tracer_level, "python_tracer_level": python_tracer_level, } + IS_GCS_PATH = str(log_dir).startswith("gs://") log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp()) + str_log_dir = log_dir if IS_GCS_PATH else str(log_dir_) _pywrap_profiler_plugin.trace( _strip_addresses(f"{host}:{port}", _GRPC_PREFIX), - str(log_dir_), + str_log_dir, '', True, duration_in_ms, DEFAULT_NUM_TRACING_ATTEMPTS, options, ) - print(f"Dumped profiling information in: {log_dir_}") + print(f"Dumped profiling information in: {str_log_dir}") + # Traces stored on GCS cannot be converted to a Perfetto trace, as JAX doesn't + # directly support GCS paths. + if IS_GCS_PATH: + if not no_perfetto_link: + print("Perfetto link is not supported for GCS paths, skipping creation.") + return # The profiler dumps `xplane.pb` to the logging directory. To upload it to # the Perfetto trace viewer, we need to convert it to a `trace.json` file. # We do this by first finding the `xplane.pb` file, then passing it into diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 14ebdef70267..423f2570d705 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -22,6 +22,7 @@ import threading import time import unittest +import unittest.mock from absl.testing import absltest import pathlib @@ -41,6 +42,7 @@ try: from xprof.convert import _pywrap_profiler_plugin + import jax.collect_profile except ImportError: _pywrap_profiler_plugin = None @@ -470,5 +472,43 @@ def on_profile(): thread_profiler.join() self._check_xspace_pb_exist(logdir) + @unittest.skipIf( + not (portpicker and _pywrap_profiler_plugin), + "Test requires xprof and portpicker") + def test_remote_profiler_gcs_path(self): + port = portpicker.pick_unused_port() + jax.profiler.start_server(port) + + profile_done = threading.Event() + logdir = "gs://mock-test-bucket/test-dir" + # Mock XProf call in collect_profile. + _pywrap_profiler_plugin.trace = unittest.mock.MagicMock() + def on_profile(): + jax.collect_profile(port, 500, logdir, no_perfetto_link=True) + profile_done.set() + + thread_profiler = threading.Thread( + target=on_profile, args=()) + thread_profiler.start() + start_time = time.time() + y = jnp.zeros((5, 5)) + while not profile_done.is_set(): + # The timeout here must be relatively high. The profiler takes a while to + # start up on Cloud TPUs. + if time.time() - start_time > 30: + raise RuntimeError("Profile did not complete in 30s") + y = jnp.dot(y, y) + jax.profiler.stop_server() + thread_profiler.join() + _pywrap_profiler_plugin.trace.assert_called_once_with( + unittest.mock.ANY, + logdir, + unittest.mock.ANY, + unittest.mock.ANY, + unittest.mock.ANY, + unittest.mock.ANY, + unittest.mock.ANY, + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())