Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def start_trace(
*,
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
) -> None:
"""Starts a profiler trace.

Expand Down Expand Up @@ -131,6 +132,8 @@ def start_trace(
want to generate a Perfetto-compatible trace without blocking the process.
This feature is experimental for Pathways on Cloud and may not be fully
supported.
profiler_options: Profiler options to configure the profiler for collection.
Options are not currently supported and ignored.
"""
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
Expand Down Expand Up @@ -270,11 +273,17 @@ def monkey_patch_jax():

def start_trace_patch(
log_dir,
create_perfetto_link: bool = False, # pylint: disable=unused-argument
create_perfetto_trace: bool = False, # pylint: disable=unused-argument
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
) -> None:
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
return start_trace(log_dir)
return start_trace(
log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
profiler_options=profiler_options,
)

jax.profiler.start_trace = start_trace_patch

Expand Down
7 changes: 6 additions & 1 deletion pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,12 @@ def test_monkey_patch_jax(self):
profiling, "start_trace", autospec=True
) as mock_pw_start_trace:
jax.profiler.start_trace("gs://bucket/dir")
mock_pw_start_trace.assert_called_once_with("gs://bucket/dir")
mock_pw_start_trace.assert_called_once_with(
"gs://bucket/dir",
create_perfetto_link=False,
create_perfetto_trace=False,
profiler_options=None,
)

with mock.patch.object(
profiling, "stop_trace", autospec=True
Expand Down