Skip to content

Commit 30e426d

Browse files
Matt-Hurd authored and Google-ML-Automation committed
Add support for GCS paths in jax.collect_profile
PiperOrigin-RevId: 813873034
1 parent bf86259 commit 30e426d

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

jax/collect_profile.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,25 @@ def collect_profile(port: int, duration_in_ms: int, host: str,
7272
"device_tracer_level": device_tracer_level,
7373
"python_tracer_level": python_tracer_level,
7474
}
75+
IS_GCS_PATH = str(log_dir).startswith("gs://")
7576
log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp())
77+
str_log_dir = log_dir if IS_GCS_PATH else str(log_dir_)
7678
_pywrap_profiler_plugin.trace(
7779
_strip_addresses(f"{host}:{port}", _GRPC_PREFIX),
78-
str(log_dir_),
80+
str_log_dir,
7981
'',
8082
True,
8183
duration_in_ms,
8284
DEFAULT_NUM_TRACING_ATTEMPTS,
8385
options,
8486
)
85-
print(f"Dumped profiling information in: {log_dir_}")
87+
print(f"Dumped profiling information in: {str_log_dir}")
88+
# Traces stored on GCS cannot be converted to a Perfetto trace, as JAX doesn't
89+
# directly support GCS paths.
90+
if IS_GCS_PATH:
91+
if not no_perfetto_link:
92+
print("Perfetto link is not supported for GCS paths, skipping creation.")
93+
return
8694
# The profiler dumps `xplane.pb` to the logging directory. To upload it to
8795
# the Perfetto trace viewer, we need to convert it to a `trace.json` file.
8896
# We do this by first finding the `xplane.pb` file, then passing it into

tests/profiler_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import threading
2222
import time
2323
import unittest
24+
import unittest.mock
2425
from absl.testing import absltest
2526
import pathlib
2627

@@ -39,6 +40,7 @@
3940

4041
try:
4142
from xprof.convert import _pywrap_profiler_plugin
43+
import jax.collect_profile
4244
except ImportError:
4345
_pywrap_profiler_plugin = None
4446

@@ -435,5 +437,43 @@ def on_profile():
435437
thread_profiler.join()
436438
self._check_xspace_pb_exist(logdir)
437439

440+
@unittest.skipIf(
    not (portpicker and _pywrap_profiler_plugin),
    "Test requires xprof and portpicker")
def test_remote_profiler_gcs_path(self):
  """Checks that a `gs://` log directory is handed to the XProf trace call.

  The XProf `trace` entry point is replaced by a mock, so nothing is written
  to GCS. The test starts a profiler server, runs `collect_profile` on a
  background thread while the main thread keeps issuing work, and then asserts
  that the mock was called exactly once with the `gs://` path as its second
  positional argument.

  Raises:
    RuntimeError: if the profile does not complete within 30 seconds.
  """
  port = portpicker.pick_unused_port()
  jax.profiler.start_server(port)

  profile_done = threading.Event()
  logdir = "gs://mock-test-bucket/test-dir"

  def on_profile():
    # NOTE(review): the visible `collect_profile` signature takes `host` as
    # its third positional parameter, and `import jax.collect_profile` binds a
    # module. Confirm that this call resolves to the intended function and
    # that `logdir` is bound to `log_dir`.
    jax.collect_profile(port, 500, logdir, no_perfetto_link=True)
    profile_done.set()

  # Mock the XProf call in collect_profile. `patch.object` restores the real
  # `trace` when the block exits, so the mock cannot leak into other tests
  # (a plain attribute assignment would stay in place for the whole process).
  with unittest.mock.patch.object(
      _pywrap_profiler_plugin, "trace") as mock_trace:
    thread_profiler = threading.Thread(target=on_profile, args=())
    thread_profiler.start()
    try:
      start_time = time.time()
      y = jnp.zeros((5, 5))
      while not profile_done.is_set():
        # The timeout here must be relatively high. The profiler takes a
        # while to start up on Cloud TPUs.
        if time.time() - start_time > 30:
          raise RuntimeError("Profile did not complete in 30s")
        y = jnp.dot(y, y)
    finally:
      # Always stop the server, even on timeout, so that later tests can
      # start their own profiler server.
      jax.profiler.stop_server()
    thread_profiler.join()
    mock_trace.assert_called_once_with(
        unittest.mock.ANY,
        logdir,
        unittest.mock.ANY,
        unittest.mock.ANY,
        unittest.mock.ANY,
        unittest.mock.ANY,
        unittest.mock.ANY,
    )
477+
438478
if __name__ == "__main__":
439479
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)