diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index f8ba47f..d6c539d 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -286,12 +286,14 @@ def start_trace_patch( ) jax.profiler.start_trace = start_trace_patch + jax._src.profiler.start_trace = start_trace_patch # pylint: disable=protected-access def stop_trace_patch() -> None: _logger.debug("jax.profile.stop_trace patched with pathways' stop_trace") return stop_trace() jax.profiler.stop_trace = stop_trace_patch + jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access def start_server_patch(port: int): _logger.debug( diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index fd87093..f62b36c 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -305,53 +305,86 @@ def test_stop_server_does_nothing_if_server_exists(self): profiling.start_server(9000) profiling.stop_server() # Should not raise - def test_monkey_patch_jax(self): - original_jax_start_trace = jax.profiler.start_trace - original_jax_stop_trace = jax.profiler.stop_trace - original_jax_start_server = jax.profiler.start_server - original_jax_stop_server = jax.profiler.stop_server + def _setup_monkey_patch(self): + """Saves originals, applies monkey patch, and sets up mocks.""" + targets = [ + (jax.profiler, "start_trace"), + (jax.profiler, "stop_trace"), + (jax.profiler, "start_server"), + (jax.profiler, "stop_server"), + (jax._src.profiler, "start_trace"), + (jax._src.profiler, "stop_trace"), + ] + original_jax_funcs = {} + for module, func_name in targets: + original_func = getattr(module, func_name) + original_jax_funcs[(module, func_name)] = original_func + self.addCleanup(setattr, module, func_name, original_func) profiling.monkey_patch_jax() - self.assertNotEqual(jax.profiler.start_trace, original_jax_start_trace) - self.assertNotEqual(jax.profiler.stop_trace, original_jax_stop_trace) - self.assertNotEqual(jax.profiler.start_server, original_jax_start_server) - self.assertNotEqual(jax.profiler.stop_server, original_jax_stop_server) - - with mock.patch.object( - profiling, "start_trace", autospec=True - ) as mock_pw_start_trace: - jax.profiler.start_trace("gs://bucket/dir") - mock_pw_start_trace.assert_called_once_with( - "gs://bucket/dir", - create_perfetto_link=False, - create_perfetto_trace=False, - profiler_options=None, + for module, func_name in targets: + self.assertNotEqual( + getattr(module, func_name), + original_jax_funcs[(module, func_name)], ) - with mock.patch.object( - profiling, "stop_trace", autospec=True - ) as mock_pw_stop_trace: - jax.profiler.stop_trace() - mock_pw_stop_trace.assert_called_once() - - with mock.patch.object( - profiling, "start_server", autospec=True - ) as mock_pw_start_server: - jax.profiler.start_server(1234) - mock_pw_start_server.assert_called_once_with(1234) - - with mock.patch.object( - profiling, "stop_server", autospec=True - ) as mock_pw_stop_server: - jax.profiler.stop_server() - mock_pw_stop_server.assert_called_once() - - # Restore original jax functions - jax.profiler.start_trace = original_jax_start_trace - jax.profiler.stop_trace = original_jax_stop_trace - jax.profiler.start_server = original_jax_start_server - jax.profiler.stop_server = original_jax_stop_server + mocks = { + "start_trace": self.enter_context( + mock.patch.object(profiling, "start_trace", autospec=True) + ), + "stop_trace": self.enter_context( + mock.patch.object(profiling, "stop_trace", autospec=True) + ), + "start_server": self.enter_context( + mock.patch.object(profiling, "start_server", autospec=True) + ), + "stop_server": self.enter_context( + mock.patch.object(profiling, "stop_server", autospec=True) + ), + } + return mocks + + @parameterized.named_parameters( + dict(testcase_name="jax_profiler", profiler_module=jax.profiler), + dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler), + ) + def test_monkey_patched_start_trace(self, profiler_module): + mocks = self._setup_monkey_patch() + + profiler_module.start_trace("gs://bucket/dir") + + mocks["start_trace"].assert_called_once_with( + "gs://bucket/dir", + create_perfetto_link=False, + create_perfetto_trace=False, + profiler_options=None, + ) + + @parameterized.named_parameters( + dict(testcase_name="jax_profiler", profiler_module=jax.profiler), + dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler), + ) + def test_monkey_patched_stop_trace(self, profiler_module): + mocks = self._setup_monkey_patch() + + profiler_module.stop_trace() + + mocks["stop_trace"].assert_called_once() + + def test_monkey_patched_start_server(self): + mocks = self._setup_monkey_patch() + + jax.profiler.start_server(1234) + + mocks["start_server"].assert_called_once_with(1234) + + def test_monkey_patched_stop_server(self): + mocks = self._setup_monkey_patch() + + jax.profiler.stop_server() + + mocks["stop_server"].assert_called_once() def test_create_profile_request_no_options(self): request = profiling._create_profile_request("gs://bucket/dir") @@ -389,6 +422,7 @@ def test_create_profile_request_no_options(self): }, },), ) + def test_start_pathways_trace_from_profile_request(self, profile_request): profiling._start_pathways_trace_from_profile_request(profile_request) @@ -412,6 +446,15 @@ def test_original_stop_trace_called_on_stop_failure(self): profiling.stop_trace() self.mock_original_stop_trace.assert_called_once() + def test_jax_profiler_trace_calls_patched_functions(self): + mocks = self._setup_monkey_patch() + + with jax.profiler.trace("gs://bucket/dir"): + pass + + mocks["start_trace"].assert_called_once() + mocks["stop_trace"].assert_called_once() + if __name__ == "__main__": absltest.main()