Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,14 @@ def start_trace_patch(
)

jax.profiler.start_trace = start_trace_patch
jax._src.profiler.start_trace = start_trace_patch # pylint: disable=protected-access

def stop_trace_patch() -> None:
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
return stop_trace()

jax.profiler.stop_trace = stop_trace_patch
jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access

def start_server_patch(port: int):
_logger.debug(
Expand Down
127 changes: 85 additions & 42 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,53 +305,86 @@ def test_stop_server_does_nothing_if_server_exists(self):
profiling.start_server(9000)
profiling.stop_server() # Should not raise

def test_monkey_patch_jax(self):
original_jax_start_trace = jax.profiler.start_trace
original_jax_stop_trace = jax.profiler.stop_trace
original_jax_start_server = jax.profiler.start_server
original_jax_stop_server = jax.profiler.stop_server
def _setup_monkey_patch(self):
"""Saves originals, applies monkey patch, and sets up mocks."""
targets = [
(jax.profiler, "start_trace"),
(jax.profiler, "stop_trace"),
(jax.profiler, "start_server"),
(jax.profiler, "stop_server"),
(jax._src.profiler, "start_trace"),
(jax._src.profiler, "stop_trace"),
]
original_jax_funcs = {}
for module, func_name in targets:
original_func = getattr(module, func_name)
original_jax_funcs[(module, func_name)] = original_func
self.addCleanup(setattr, module, func_name, original_func)

profiling.monkey_patch_jax()

self.assertNotEqual(jax.profiler.start_trace, original_jax_start_trace)
self.assertNotEqual(jax.profiler.stop_trace, original_jax_stop_trace)
self.assertNotEqual(jax.profiler.start_server, original_jax_start_server)
self.assertNotEqual(jax.profiler.stop_server, original_jax_stop_server)

with mock.patch.object(
profiling, "start_trace", autospec=True
) as mock_pw_start_trace:
jax.profiler.start_trace("gs://bucket/dir")
mock_pw_start_trace.assert_called_once_with(
"gs://bucket/dir",
create_perfetto_link=False,
create_perfetto_trace=False,
profiler_options=None,
for module, func_name in targets:
self.assertNotEqual(
getattr(module, func_name),
original_jax_funcs[(module, func_name)],
)

with mock.patch.object(
profiling, "stop_trace", autospec=True
) as mock_pw_stop_trace:
jax.profiler.stop_trace()
mock_pw_stop_trace.assert_called_once()

with mock.patch.object(
profiling, "start_server", autospec=True
) as mock_pw_start_server:
jax.profiler.start_server(1234)
mock_pw_start_server.assert_called_once_with(1234)

with mock.patch.object(
profiling, "stop_server", autospec=True
) as mock_pw_stop_server:
jax.profiler.stop_server()
mock_pw_stop_server.assert_called_once()

# Restore original jax functions
jax.profiler.start_trace = original_jax_start_trace
jax.profiler.stop_trace = original_jax_stop_trace
jax.profiler.start_server = original_jax_start_server
jax.profiler.stop_server = original_jax_stop_server
mocks = {
"start_trace": self.enter_context(
mock.patch.object(profiling, "start_trace", autospec=True)
),
"stop_trace": self.enter_context(
mock.patch.object(profiling, "stop_trace", autospec=True)
),
"start_server": self.enter_context(
mock.patch.object(profiling, "start_server", autospec=True)
),
"stop_server": self.enter_context(
mock.patch.object(profiling, "stop_server", autospec=True)
),
}
return mocks

@parameterized.named_parameters(
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler),
)
def test_monkey_patched_start_trace(self, profiler_module):
mocks = self._setup_monkey_patch()

profiler_module.start_trace("gs://bucket/dir")

mocks["start_trace"].assert_called_once_with(
"gs://bucket/dir",
create_perfetto_link=False,
create_perfetto_trace=False,
profiler_options=None,
)

@parameterized.named_parameters(
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler),
)
def test_monkey_patched_stop_trace(self, profiler_module):
mocks = self._setup_monkey_patch()

profiler_module.stop_trace()

mocks["stop_trace"].assert_called_once()

def test_monkey_patched_start_server(self):
mocks = self._setup_monkey_patch()

jax.profiler.start_server(1234)

mocks["start_server"].assert_called_once_with(1234)

def test_monkey_patched_stop_server(self):
mocks = self._setup_monkey_patch()

jax.profiler.stop_server()

mocks["stop_server"].assert_called_once()

def test_create_profile_request_no_options(self):
request = profiling._create_profile_request("gs://bucket/dir")
Expand Down Expand Up @@ -389,6 +422,7 @@ def test_create_profile_request_no_options(self):
},
},),
)

def test_start_pathways_trace_from_profile_request(self, profile_request):
profiling._start_pathways_trace_from_profile_request(profile_request)

Expand All @@ -412,6 +446,15 @@ def test_original_stop_trace_called_on_stop_failure(self):
profiling.stop_trace()
self.mock_original_stop_trace.assert_called_once()

def test_jax_profiler_trace_calls_patched_functions(self):
mocks = self._setup_monkey_patch()

with jax.profiler.trace("gs://bucket/dir"):
pass

mocks["start_trace"].assert_called_once()
mocks["stop_trace"].assert_called_once()


if __name__ == "__main__":
absltest.main()