Skip to content

Conversation

@ptrendx
Copy link
Member

@ptrendx ptrendx commented Dec 1, 2025

Description

This PR includes a few performance optimizations targeting the CPU overhead. The code, perf numbers etc. are WIP. The code gets kind of ugly though :-(.

For the prepare_forward changes I did not touch attention (@cyanguwa FYI) since it has multiple exit points from the forward and was worried that I would miss something there - it would be great if we could refactor that part first to have a single return statement instead.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ptrendx
Copy link
Member Author

ptrendx commented Dec 1, 2025

/te-ci pytorch

Comment on lines 644 to 646
def fast_set_attr(self, name: str, value: Any) -> None:
self.__dict__[name] = value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we are separating out this function so we can manually avoid overheads from __setattr__ and dict? Doing some benchmarking:

  • dict read: 9 ns
  • dict write: 13 ns
  • dict in: 9 ns
  • dict.get: 14 ns
  • Function call: 9 ns
  • Class attr read: 3 ns
  • Class attr write: 5 ns
  • Class custom getattr: 101 ns
  • Class custom setattr: 134 ns
Benchmarking script

I ran the following on a GB200 node. For the dict times, I subtracted out the overhead from list reads. For the class getattr/setattr times, I subtracted out the overhead from range.

import contextlib
import time

class Timer:
    """Measure time interval."""

    def __init__(self) -> None:
        self._start = None
        self._end = None

    def time(self) -> float:
	"""CPU time interval in seconds."""
        return self._end - self._start

    @contextlib.contextmanager
    def context(self):
        """Context manager to capture time interval."""
	self._start = time.perf_counter()
        yield
        self._end = time.perf_counter()

def main() -> None:

    # Options
    iters = 1024 * 1024

    # Timer
    timer = Timer()

    # Dummy data
    str_list = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit"]
    str_list = [str_list[i % len(str_list)] for i in range(iters)]
    str_dict = {s: len(s) for s in str_list}
    class PlainClass:
        def __init__(self) -> None:
            self.attr = 1
    class CustomGetattrSetattrClass:
        def __init__(self) -> None:
            self.attr = 1
        def __getattribute__(self, name):
            return super().__getattribute__(name)
	def __setattr__(self, name, val):
            super().__setattr__(name, val)

    # Timer overhead
    with timer.context():
        pass
    print(f"Timer overhead: {timer.time() * 1e9 / iters} ns/iter")

    # Range loop
    with timer.context():
        for _ in range(iters):
            pass
    print(f"Range loop: {timer.time() * 1e9 / iters} ns/iter")

    # List loop
    with timer.context():
        for _ in str_list:
            pass
    print(f"List loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, j in enumerate(range(iters)):
            pass
    print(f"Range+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, s in enumerate(str_list):
            pass
    print(f"List+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # List reads
    with timer.context():
        for i in range(iters):
            str_list[i]
    print(f"List reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict reads
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]]
    print(f"Dict reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict get
    with timer.context():
        for i in range(iters):
            str_dict.get(str_list[i], None)
    print(f"Dict gets: {timer.time() * 1e9 / iters} ns/iter")

    # Dict writes
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]] = i
    print(f"Dict writes: {timer.time() * 1e9 / iters} ns/iter")

    # Dict membership
    with timer.context():
        for i in range(iters):
            str_list[i] in str_dict
    print(f"Dict membership: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    def func() -> None:
        pass
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Function call: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    func = lambda: None
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Lambda call: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr read
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class attr read: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr write
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class attr write: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            getattr(myobj, "attr", None)
    print(f"getattr: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            setattr(myobj, "attr", i)
    print(f"setattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom getattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class custom getattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom setattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class custom setattr: {timer.time() * 1e9 / iters} ns/iter")

if __name__ == "__main__":
    main()

How much perf difference do you observe from fast_set_attr? I could see how it could save us ~1 us of overhead, but it would be good to make sure before making the code messier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to comment too much on the perf results yet since up till now they all come from my machine and not a real cluster, but that anecdotal evidence shows that the time of the small test of just running BF16 Linear layer forward for many iterations after the proposed code changes go from 9.2 to 7.7 s. The fast_set_attr alone brought it to ~8.4s.
I will test it properly and report the timings in the description of the PR.
Now, about introducing the separate function - since ultimately this is the optimization that you came up with at some point, there already was the machinery to not do the expensive Module.set_attr for some parameters. The problem that I see is discoverability - if people do not study that code very cautiously they will not realize that they should not just do self.something = something. Therefore I think we should actually go a more explicit way and in the set_attr of TE module just error out with a message to either use fast_set_attr for the things we are sure are just small values (since the usage of dict directly has some problems BTW since it e.g. bypasses properties and stuff) and use a new function, let's call it just set_attr for anything where we need the full machinery.

Copy link
Collaborator

@timmoon10 timmoon10 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer not to ban self.something = something. I think readability and safety are more important for non-performance-critical things like initialization and checkpointing. It would be better to make this function an advanced internal implementation with a name like _fast_setattr.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would we then make sure that this does not resurface in the future?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went with the explicit setattr calls and having a warning issued when the regular setattr function is used. That way the users can still use the regular setattr call if they want, but for the internal development we make sure during testing that the warning does not trigger. To make the code less ugly we only turn on the warning after the constructor is finished - that way we can still use the nice syntax during construction (where there are the most occurences) since we do not care about the speed there.

@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from 5eefe3e to 1c7d896 Compare December 2, 2025 22:45
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch 2 times, most recently from d149f6b to 2ce0f34 Compare December 15, 2025 22:39
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from 948747b to c4e380f Compare December 16, 2025 21:20
pre-commit-ci bot and others added 7 commits December 16, 2025 21:21
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Copy link
Member Author

ptrendx commented Jan 10, 2026

/te-ci pytorch

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx marked this pull request as ready for review January 10, 2026 00:48
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 10, 2026

Greptile Overview

Greptile Summary

This PR implements CPU performance optimizations for PyTorch modules in Transformer Engine by reducing the overhead of Python's __setattr__ mechanism and NVTX profiling setup. The key changes include:

Core Optimizations:

  • Introduced fast_setattr() method that directly writes to __dict__ to bypass PyTorch's module attribute handling overhead
  • Modified __setattr__ to emit RuntimeWarning when called after initialization, enforcing use of fast_setattr()
  • Refactored prepare_forward() from context manager to explicit call pattern with separate end_forward() for reduced CPU overhead
  • Added prepare_forward_ctx() as a wrapper for modules that need the context manager semantics (attention modules)
  • Optimized C++ allocator by moving mutex lock after early return checks

Test Infrastructure:

  • Added pytest.ini to treat RuntimeWarning as errors, enforcing the fast_setattr() usage pattern
  • Updated all test runner scripts to use the new pytest configuration

Critical Issues Found:

  1. Indentation Bug (linear.py line 431): The mark_not_offload() call was incorrectly indented inside the if cpu_offloading: block. This is a critical logic error that will break offloading functionality when CPU offloading is disabled.

  2. NVTX Range Stack Leak (4 files): The refactoring from context managers to explicit prepare_forward()/end_forward() calls introduces exception-safety issues. If an exception occurs between these calls, the NVTX range pushed in prepare_forward() will never be popped, corrupting the profiling stack. This affects:

    • linear.py
    • layernorm_linear.py
    • layernorm_mlp.py
    • grouped_linear.py

    The attention module correctly continues using the context manager pattern via prepare_forward_ctx().

Architecture Notes:
The PR maintains backward compatibility by keeping both the context manager version (prepare_forward_ctx) and the explicit version (prepare_forward + end_forward). However, the explicit version requires careful exception handling to prevent resource leaks, which is currently missing.

Confidence Score: 1/5

  • This PR has critical bugs that will cause runtime failures and profiling corruption
  • Score of 1 reflects two critical issues: (1) an indentation bug in linear.py that breaks the mark_not_offload logic - this will cause incorrect behavior when CPU offloading is disabled, and (2) NVTX range stack leaks in 4 module files due to missing exception handling. The NVTX leak may not cause immediate crashes but will corrupt profiling data and could accumulate over time. The indentation bug is a clear logic error that must be fixed before merge. While the performance optimization approach is sound, the execution introduces these serious bugs that outweigh the benefits.
  • Critical attention needed for transformer_engine/pytorch/module/linear.py (indentation bug at line 431), and exception handling must be added to linear.py, layernorm_linear.py, layernorm_mlp.py, and grouped_linear.py to prevent NVTX stack leaks

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/linear.py 1/5 Critical indentation bug in mark_not_offload call; NVTX range leak risk in forward method
transformer_engine/pytorch/module/base.py 3/5 Refactored setattr for performance; split prepare_forward into non-context and context versions; uses fast_setattr throughout
transformer_engine/pytorch/module/layernorm_linear.py 2/5 NVTX range leak risk due to missing exception handling in forward method; sets _initialized flag
transformer_engine/pytorch/module/layernorm_mlp.py 2/5 NVTX range leak risk due to missing exception handling in forward method; sets _initialized flag
transformer_engine/pytorch/module/grouped_linear.py 2/5 NVTX range leak risk due to missing exception handling in forward method; sets _initialized flag
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py 4/5 Uses fast_setattr for performance optimization; correctly uses prepare_forward_ctx context manager; sets _initialized flag
transformer_engine/common/transformer_engine.cpp 4/5 Minor performance optimization: moves mutex lock after early return check in Free methods
tests/pytorch/pytest.ini 5/5 New pytest configuration to treat RuntimeWarning as errors, enforcing proper use of fast_setattr

Sequence Diagram

sequenceDiagram
    participant User
    participant Linear as Linear.forward()
    participant Base as TransformerEngineBaseModule
    participant NVTX as NVTX Profiler
    participant FP8 as FP8GlobalStateManager

    User->>Linear: forward(inp)
    
    Note over Linear: Check ONNX export (early return if true)
    
    Linear->>Base: prepare_forward(inp)
    Base->>Base: get_activation_dtype()
    Base->>Base: init_fp8_metadata()
    Base->>FP8: is_fp8_enabled()
    FP8-->>Base: fp8_enabled
    Base->>Base: fast_setattr("fp8", fp8_enabled)
    
    alt FP8 enabled
        Base->>FP8: copy_forward_fp8_meta_tensors_for_recompute()
    end
    
    Base->>NVTX: nvtx_range_push("Linear forward")
    Base->>Base: Make input contiguous if needed
    Base-->>Linear: inp (prepared)
    
    Note over Linear: RISK: Exception here causes NVTX leak
    
    Linear->>Linear: _get_weight_and_bias_tensors()
    Linear->>Linear: _get_quantizers()
    Linear->>Linear: linear_fn(*args)
    
    Note over Linear: All processing happens here
    
    Linear->>Base: end_forward()
    
    alt FP8 with DelayedScaling
        Base->>FP8: restore_fp8_meta_tensors()
    end
    
    Base->>NVTX: nvtx_range_pop()
    
    Linear-->>User: output
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1402 to +1474
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The refactoring from context manager to explicit prepare_forward() and end_forward() calls introduces an NVTX range stack leak risk. If an exception is raised anywhere between line 1402 (where prepare_forward() calls nvtx_range_push()) and line 1474 (where end_forward() calls nvtx_range_pop()), the NVTX range will never be popped from the stack.

The original context manager approach (with self.prepare_forward()) was safer because it guaranteed cleanup even when exceptions occurred. Without proper exception handling, NVTX ranges can accumulate on the stack, corrupting profiling data.

Consider wrapping the code between prepare_forward() and end_forward() in a try-finally block:

inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
try:
    # ... all the processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()

This same issue exists in layernorm_linear.py, layernorm_mlp.py, and grouped_linear.py.

Comment on lines +1519 to +1601
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)

self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same NVTX range stack leak vulnerability as in linear.py. If an exception occurs between prepare_forward() (line 1519) and end_forward() (line 1601), the NVTX range pushed in prepare_forward() will never be popped, corrupting the profiling stack.

The original context manager pattern was exception-safe. Consider using try-finally to ensure end_forward() is always called:

inp = self.prepare_forward(inp, allow_non_contiguous=False)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()

Comment on lines +2057 to +2159
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same NVTX range stack leak vulnerability as in linear.py and layernorm_linear.py. If an exception occurs between prepare_forward() (line 2057) and end_forward() (line 2159), the NVTX range will leak.

Consider using try-finally to ensure cleanup:

inp = self.prepare_forward(inp, num_gemms=2)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()

Comment on lines +793 to +847
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers

if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same NVTX range stack leak vulnerability. If an exception occurs between prepare_forward() (line 793) and end_forward() (line 847), the NVTX range will leak.

Consider using try-finally to ensure cleanup:

inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
try:
    # ... processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 10, 2026

Greptile Overview

Greptile Summary

This PR attempts to optimize CPU overhead by:

  1. Replacing context managers with manual prepare_forward()/end_forward() calls
  2. Introducing fast_setattr() to bypass PyTorch's __setattr__ overhead
  3. Adding an __setattr__ override that warns when the slow path is used
  4. Optimizing C++ mutex locking by moving null checks before lock acquisition
  5. Configuring pytest to treat RuntimeWarnings as errors

Critical Issues Found

1. NVTX Range Imbalance on Exceptions (HIGH SEVERITY)

The refactoring from context managers to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between these calls (e.g., shape mismatch, CUDA OOM, assertion failure), nvtx_range_pop() is never called, corrupting the NVTX stack. This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules.

2. setattr Doesn't Actually Optimize (HIGH SEVERITY)

The new __setattr__ override still calls super().__setattr__(name, value) after emitting a warning, meaning every attribute assignment still goes through the slow PyTorch path. This defeats the purpose of the optimization.

3. Multiple RuntimeWarning Violations (CRITICAL SEVERITY)

Six locations in base.py use direct attribute assignment after initialization (lines 965, 966, 1558, 1559, 1565, 1581, 1608). Since pytest.ini now treats RuntimeWarnings as errors, all tests will fail.

Positive Aspects

  • C++ mutex optimization is correct and beneficial
  • Attention module correctly uses prepare_forward_ctx context manager
  • All module subclasses properly set _initialized flag
  • Test scripts correctly updated to use pytest.ini

Recommendation

This PR cannot be merged in its current state due to the RuntimeWarning violations that will cause all tests to fail. The NVTX exception safety issue is also critical for production use.

Confidence Score: 0/5

  • This PR is not safe to merge - it will cause all tests to fail due to RuntimeWarning violations
  • Score of 0 reflects critical issues that will break the build: (1) Six direct attribute assignments trigger RuntimeWarnings which pytest.ini treats as errors, causing all tests to fail immediately; (2) NVTX range imbalance on exceptions will corrupt profiling; (3) setattr optimization doesn't actually work as intended
  • transformer_engine/pytorch/module/base.py requires immediate attention - must fix all direct attribute assignments (lines 965, 966, 1558, 1559, 1565, 1581, 1608) and address exception safety for NVTX ranges

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/base.py 1/5 Critical bugs: NVTX range imbalance on exceptions, setattr doesn't optimize, multiple RuntimeWarning-triggering assignments
transformer_engine/pytorch/module/linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/layernorm_linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/layernorm_mlp.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/grouped_linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
tests/pytorch/pytest.ini 3/5 New file that treats RuntimeWarning as error - exposes issues with setattr implementation in base.py

Sequence Diagram

sequenceDiagram
    participant User
    participant Module as Linear/LayerNorm Module
    participant Base as TransformerEngineBaseModule
    participant NVTX as NVTX Range Stack
    participant FP8 as FP8GlobalStateManager

    User->>Module: forward(inp)
    Module->>Base: prepare_forward(inp)
    Base->>Base: init_fp8_metadata()
    Base->>Base: get_amp_dtype()
    Base->>FP8: Copy FP8 metadata (if recompute)
    Base->>NVTX: nvtx_range_push("Module forward")
    Note over NVTX: Range pushed - needs cleanup!
    Base-->>Module: returns processed inp
    
    alt Exception occurs during forward
        Module->>Module: Compute quantizers/weights
        Module--xModule: Exception raised!
        Note over NVTX: ⚠️ BUG: nvtx_range_pop() never called!
        Note over NVTX: Stack becomes imbalanced
    else Normal execution
        Module->>Module: Compute quantizers/weights
        Module->>Module: Call linear_fn()
        Module->>Base: end_forward()
        Base->>FP8: Restore FP8 metadata (if recompute)
        Base->>NVTX: nvtx_range_pop()
        Note over NVTX: Range properly cleaned up
        Base-->>Module: return
        Module-->>User: return output
    end
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1086 to +1094
def end_forward(self):
"""
Required to be called at the end of the forward function to properly handle
DelayedScaling metadata handling and the NVTX ranges.
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new non-context-manager approach to prepare_forward() and end_forward() creates a critical bug: if an exception is raised between the prepare_forward() call and the end_forward() call in any forward method, the NVTX range will never be popped, leading to an imbalanced NVTX stack.

The original context manager approach with with self.prepare_forward(...) as inp: guaranteed that cleanup would happen even on exceptions. The new approach breaks this guarantee.

Impact: All forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear are affected. Any exception during forward (e.g., shape mismatch, CUDA OOM, assertion failure) will leave the NVTX range stack corrupted.

Recommendation: Either:

  1. Revert to the context manager approach (accepting the performance cost), or
  2. Wrap the forward logic in try/finally to ensure end_forward() is always called:
inp = self.prepare_forward(inp, ...)
try:
    # forward logic here
    out = linear_fn(...)
finally:
    self.end_forward()

Comment on lines 651 to +660
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __setattr__ override doesn't actually optimize anything - it still calls super().__setattr__(name, value) after emitting the warning. This means every attribute assignment still goes through the slow PyTorch Module path, defeating the purpose of the optimization.

The warning is also emitted on EVERY attribute assignment after initialization, which would be extremely noisy and unhelpful for debugging.

Additionally, there are multiple places in the codebase that still use direct attribute assignment after initialization:

  • Line 965: self.tp_group = tp_group in set_tensor_parallel_group()
  • Line 966: self.tp_group_initialized = True
  • Lines 1558, 1559, 1565, 1581, 1608: Various debug-related attribute assignments

Since pytest.ini now treats RuntimeWarnings as errors, all of these will cause test failures.

Recommendation: Either:

  1. Remove the __setattr__ override entirely and rely on developers remembering to use fast_setattr, or
  2. Update all the remaining direct assignments to use fast_setattr or module_setattr

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 10, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
Direct attribute assignments after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
Direct attribute assignments after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants