Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions jax/collect_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,25 @@ def collect_profile(port: int, duration_in_ms: int, host: str,
"device_tracer_level": device_tracer_level,
"python_tracer_level": python_tracer_level,
}
IS_GCS_PATH = str(log_dir).startswith("gs://")
log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp())
str_log_dir = log_dir if IS_GCS_PATH else str(log_dir_)
_pywrap_profiler_plugin.trace(
_strip_addresses(f"{host}:{port}", _GRPC_PREFIX),
str(log_dir_),
str_log_dir,
'',
True,
duration_in_ms,
DEFAULT_NUM_TRACING_ATTEMPTS,
options,
)
print(f"Dumped profiling information in: {log_dir_}")
print(f"Dumped profiling information in: {str_log_dir}")
# Traces stored on GCS cannot be converted to a Perfetto trace, as JAX doesn't
# directly support GCS paths.
if IS_GCS_PATH:
if not no_perfetto_link:
print("Perfetto link is not supported for GCS paths, skipping creation.")
return
# The profiler dumps `xplane.pb` to the logging directory. To upload it to
# the Perfetto trace viewer, we need to convert it to a `trace.json` file.
# We do this by first finding the `xplane.pb` file, then passing it into
Expand Down
40 changes: 40 additions & 0 deletions tests/profiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import threading
import time
import unittest
import unittest.mock
from absl.testing import absltest
import pathlib

Expand All @@ -41,6 +42,7 @@

try:
from xprof.convert import _pywrap_profiler_plugin
import jax.collect_profile
except ImportError:
_pywrap_profiler_plugin = None

Expand Down Expand Up @@ -470,5 +472,43 @@ def on_profile():
thread_profiler.join()
self._check_xspace_pb_exist(logdir)

@unittest.skipIf(
not (portpicker and _pywrap_profiler_plugin),
"Test requires xprof and portpicker")
def test_remote_profiler_gcs_path(self):
port = portpicker.pick_unused_port()
jax.profiler.start_server(port)

profile_done = threading.Event()
logdir = "gs://mock-test-bucket/test-dir"
# Mock XProf call in collect_profile.
_pywrap_profiler_plugin.trace = unittest.mock.MagicMock()
def on_profile():
jax.collect_profile(port, 500, logdir, no_perfetto_link=True)
profile_done.set()

thread_profiler = threading.Thread(
target=on_profile, args=())
thread_profiler.start()
start_time = time.time()
y = jnp.zeros((5, 5))
while not profile_done.is_set():
# The timeout here must be relatively high. The profiler takes a while to
# start up on Cloud TPUs.
if time.time() - start_time > 30:
raise RuntimeError("Profile did not complete in 30s")
y = jnp.dot(y, y)
jax.profiler.stop_server()
thread_profiler.join()
_pywrap_profiler_plugin.trace.assert_called_once_with(
unittest.mock.ANY,
logdir,
unittest.mock.ANY,
unittest.mock.ANY,
unittest.mock.ANY,
unittest.mock.ANY,
unittest.mock.ANY,
)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
Loading