[pyTorch] CPU performance optimizations #2439
base: main
Conversation
/te-ci pytorch
```python
def fast_set_attr(self, name: str, value: Any) -> None:
    self.__dict__[name] = value
```
I assume we are separating out this function so we can manually avoid overheads from __setattr__ and dict? Doing some benchmarking:
- Dict read: 9 ns
- Dict write: 13 ns
- Dict in: 9 ns
- dict.get: 14 ns
- Function call: 9 ns
- Class attr read: 3 ns
- Class attr write: 5 ns
- Class custom getattr: 101 ns
- Class custom setattr: 134 ns
Benchmarking script
I ran the following on a GB200 node. For the dict times, I subtracted out the overhead from list reads. For the class getattr/setattr times, I subtracted out the overhead from range.
```python
import contextlib
import time


class Timer:
    """Measure time interval."""

    def __init__(self) -> None:
        self._start = None
        self._end = None

    def time(self) -> float:
        """CPU time interval in seconds."""
        return self._end - self._start

    @contextlib.contextmanager
    def context(self):
        """Context manager to capture time interval."""
        self._start = time.perf_counter()
        yield
        self._end = time.perf_counter()


def main() -> None:

    # Options
    iters = 1024 * 1024

    # Timer
    timer = Timer()

    # Dummy data
    str_list = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit"]
    str_list = [str_list[i % len(str_list)] for i in range(iters)]
    str_dict = {s: len(s) for s in str_list}

    class PlainClass:
        def __init__(self) -> None:
            self.attr = 1

    class CustomGetattrSetattrClass:
        def __init__(self) -> None:
            self.attr = 1

        def __getattribute__(self, name):
            return super().__getattribute__(name)

        def __setattr__(self, name, val):
            super().__setattr__(name, val)

    # Timer overhead
    with timer.context():
        pass
    print(f"Timer overhead: {timer.time() * 1e9 / iters} ns/iter")

    # Range loop
    with timer.context():
        for _ in range(iters):
            pass
    print(f"Range loop: {timer.time() * 1e9 / iters} ns/iter")

    # List loop
    with timer.context():
        for _ in str_list:
            pass
    print(f"List loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, j in enumerate(range(iters)):
            pass
    print(f"Range+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty list+enumerate loop
    with timer.context():
        for i, s in enumerate(str_list):
            pass
    print(f"List+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # List reads
    with timer.context():
        for i in range(iters):
            str_list[i]
    print(f"List reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict reads
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]]
    print(f"Dict reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict get
    with timer.context():
        for i in range(iters):
            str_dict.get(str_list[i], None)
    print(f"Dict gets: {timer.time() * 1e9 / iters} ns/iter")

    # Dict writes
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]] = i
    print(f"Dict writes: {timer.time() * 1e9 / iters} ns/iter")

    # Dict membership
    with timer.context():
        for i in range(iters):
            str_list[i] in str_dict
    print(f"Dict membership: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    def func() -> None:
        pass

    with timer.context():
        for _ in range(iters):
            func()
    print(f"Function call: {timer.time() * 1e9 / iters} ns/iter")

    # Lambda call
    func = lambda: None
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Lambda call: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr read
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class attr read: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr write
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class attr write: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            getattr(myobj, "attr", None)
    print(f"getattr: {timer.time() * 1e9 / iters} ns/iter")

    # setattr
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            setattr(myobj, "attr", i)
    print(f"setattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom getattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class custom getattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom setattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class custom setattr: {timer.time() * 1e9 / iters} ns/iter")


if __name__ == "__main__":
    main()
```

How much perf difference do you observe from fast_set_attr? I could see how it could save us ~1 us of overhead, but it would be good to make sure before making the code messier.
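As a rough sanity check on that estimate (my own arithmetic, assuming on the order of eight attribute writes per forward call, a number not given here):

```python
# ~134 ns per custom __setattr__ call (measured above) times a hypothetical 8 writes per forward
print(134e-9 * 8 * 1e6, "us per forward")  # ~1.07 us
```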
I don't want to comment too much on the perf results yet, since so far they all come from my machine rather than a real cluster. That said, the anecdotal evidence is that a small test that just runs a BF16 Linear layer forward for many iterations goes from 9.2 s to 7.7 s with the proposed code changes; fast_set_attr alone brought it down to ~8.4 s.
I will test it properly and report the timings in the description of the PR.
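A minimal harness in the spirit of that test (my sketch, not the script behind the 9.2 s to 7.7 s numbers; the shapes, batch size, iteration count, and use of `params_dtype=torch.bfloat16` are assumptions):

```python
import time

import torch
import transformer_engine.pytorch as te

layer = te.Linear(4096, 4096, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    for _ in range(10):  # warm-up so one-time initialization is excluded
        layer(inp)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(10_000):
        layer(inp)
    torch.cuda.synchronize()

print(f"10k BF16 Linear forwards: {time.perf_counter() - start:.2f} s")
```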
Now, about introducing the separate function: since this is ultimately an optimization you came up with at some point, the machinery to avoid the expensive Module.__setattr__ for some attributes already existed. The problem I see is discoverability - if people do not study that code very carefully, they will not realize that they should not just write self.something = something. I therefore think we should go a more explicit way and, in the __setattr__ of the TE module, just error out with a message to either use fast_set_attr for the things we are sure are just small values (using __dict__ directly has some problems, by the way, since it e.g. bypasses properties and such), or a new function - let's call it just set_attr - for anything where we need the full machinery.
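A minimal sketch of how that proposal could look (a toy, not TE code; the `_init_done` gate and everything beyond the names `fast_set_attr`/`set_attr` are my assumptions):

```python
from typing import Any

import torch


class _StrictModule(torch.nn.Module):
    """Toy illustration of the 'error out' proposal; not TE code."""

    def __init__(self) -> None:
        super().__init__()
        self.some_state = 0                 # plain assignment still allowed during __init__
        self.__dict__["_init_done"] = True  # everything after this must be explicit

    def fast_set_attr(self, name: str, value: Any) -> None:
        # Small plain values only: bypasses nn.Module bookkeeping (and properties).
        self.__dict__[name] = value

    def set_attr(self, name: str, value: Any) -> None:
        # Full machinery for Parameters, buffers, and submodules.
        torch.nn.Module.__setattr__(self, name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        if self.__dict__.get("_init_done", False):
            raise AttributeError(
                f"Use fast_set_attr()/set_attr() instead of 'self.{name} = ...'"
            )
        super().__setattr__(name, value)


m = _StrictModule()
m.fast_set_attr("some_state", 1)                              # OK: plain value
m.set_attr("weight", torch.nn.Parameter(torch.zeros(2)))      # OK: registered as a Parameter
# m.some_state = 2                                            # would raise AttributeError
```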
I'd prefer not to ban self.something = something. I think readability and safety are more important for non-performance-critical things like initialization and checkpointing. It would be better to make this function an advanced internal implementation with a name like _fast_setattr.
How would we then make sure that this does not resurface in the future?
Went with the explicit setattr calls and a warning issued when the regular setattr path is used. That way users can still use the regular setattr call if they want, but for internal development we make sure during testing that the warning does not trigger. To make the code less ugly, we only turn the warning on after the constructor has finished - that way we can still use the nice syntax during construction (where most of the occurrences are), since we do not care about the speed there.
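A sketch of how tests could then enforce this, using the standard warnings machinery (the actual pytest.ini setup in the PR may differ):

```python
import warnings

# Promote the new RuntimeWarning to an error during tests, so any stray
# `self.something = something` after construction fails loudly.
warnings.simplefilter("error", RuntimeWarning)
```

pytest can do the same declaratively via `filterwarnings = error::RuntimeWarning` in pytest.ini.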
/te-ci pytorch
Greptile Overview

Greptile Summary

This PR implements CPU performance optimizations for PyTorch modules in Transformer Engine by reducing the overhead of Python's __setattr__.

Core Optimizations:
Test Infrastructure:
Critical Issues Found:
Architecture Notes:

Confidence Score: 1/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Linear as Linear.forward()
    participant Base as TransformerEngineBaseModule
    participant NVTX as NVTX Profiler
    participant FP8 as FP8GlobalStateManager
    User->>Linear: forward(inp)
    Note over Linear: Check ONNX export (early return if true)
    Linear->>Base: prepare_forward(inp)
    Base->>Base: get_activation_dtype()
    Base->>Base: init_fp8_metadata()
    Base->>FP8: is_fp8_enabled()
    FP8-->>Base: fp8_enabled
    Base->>Base: fast_setattr("fp8", fp8_enabled)
    alt FP8 enabled
        Base->>FP8: copy_forward_fp8_meta_tensors_for_recompute()
    end
    Base->>NVTX: nvtx_range_push("Linear forward")
    Base->>Base: Make input contiguous if needed
    Base-->>Linear: inp (prepared)
    Note over Linear: RISK: Exception here causes NVTX leak
    Linear->>Linear: _get_weight_and_bias_tensors()
    Linear->>Linear: _get_quantizers()
    Linear->>Linear: linear_fn(*args)
    Note over Linear: All processing happens here
    Linear->>Base: end_forward()
    alt FP8 with DelayedScaling
        Base->>FP8: restore_fp8_meta_tensors()
    end
    Base->>NVTX: nvtx_range_pop()
    Linear-->>User: output
```
5 files reviewed, 5 comments
```python
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
    self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
    if not debug
    else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
    if self.no_debug_features_active(quantizers):
        debug = False
        quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
    input_quantizer,
    weight_quantizer,
    output_quantizer,
    grad_input_quantizer,
    grad_weight_quantizer,
    grad_output_quantizer,
) = quantizers

if is_grad_enabled:
    linear_fn = _Linear.apply
    autograd_ctx = []
else:
    linear_fn = _Linear.forward
    autograd_ctx = [None]

non_tensor_args = (
    is_first_microbatch,
    self.fp8,
    self.fp8_calibration,
    self.wgrad_store,
    input_quantizer,
    weight_quantizer,
    output_quantizer,
    grad_input_quantizer,
    grad_weight_quantizer,
    grad_output_quantizer,
    self.fuse_wgrad_accumulation,
    is_cpu_offload_enabled(),
    self.tp_group,
    self.tp_size,
    self.sequence_parallel,
    self.tp_size > 1,
    self.activation_dtype,
    self.parallel_mode,
    is_grad_enabled,
    self.ub_overlap_rs_fprop,
    self.ub_overlap_ag_dgrad,
    self.ub_overlap_ag_fprop,
    self.ub_overlap_rs_dgrad,
    self.ub_bulk_dgrad,
    self.ub_bulk_wgrad,
    self.ub_name,
    fp8_output,
    self.fsdp_group,
    self,
    skip_fp8_weight_update,
    self.symmetric_ar_type,
    self.save_original_input,
    debug,
)
out = linear_fn(
    *autograd_ctx,
    weight_tensor,
    inp,
    bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
    non_tensor_args,
)

self.end_forward()
```
The refactoring from context manager to explicit prepare_forward() and end_forward() calls introduces an NVTX range stack leak risk. If an exception is raised anywhere between line 1402 (where prepare_forward() calls nvtx_range_push()) and line 1474 (where end_forward() calls nvtx_range_pop()), the NVTX range will never be popped from the stack.
The original context manager approach (with self.prepare_forward()) was safer because it guaranteed cleanup even when exceptions occurred. Without proper exception handling, NVTX ranges can accumulate on the stack, corrupting profiling data.
Consider wrapping the code between prepare_forward() and end_forward() in a try-finally block:
```python
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
try:
    # ... all the processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()
```

This same issue exists in layernorm_linear.py, layernorm_mlp.py, and grouped_linear.py.
```python
inp = self.prepare_forward(
    inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
    self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
    if not debug
    else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
    if self.no_debug_features_active(quantizers):
        debug = False
        quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
    input_quantizer,
    weight_quantizer,
    output_quantizer,
    grad_input_quantizer,
    grad_weight_quantizer,
    grad_output_quantizer,
) = quantizers

if is_grad_enabled:
    fwd_fn = _LayerNormLinear.apply
    autograd_ctx = []
else:
    fwd_fn = _LayerNormLinear.forward
    autograd_ctx = [None]
non_tensor_args = (
    self.eps,
    is_first_microbatch,
    self.fp8,
    self.fp8_calibration,
    self.wgrad_store,
    self.fuse_wgrad_accumulation,
    input_quantizer,
    weight_quantizer,
    output_quantizer,
    grad_input_quantizer,
    grad_weight_quantizer,
    grad_output_quantizer,
    is_cpu_offload_enabled(),
    self.tp_group,
    self.tp_size,
    self.sequence_parallel,
    self.tp_size > 1,
    self.activation_dtype,
    self.parallel_mode,
    self.return_layernorm_output,
    self.return_layernorm_output_gathered,
    is_grad_enabled,
    self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
    self.bwd_ln_sm_margin,
    self.zero_centered_gamma,
    self.normalization,
    self.ub_overlap_ag_fprop,
    self.ub_overlap_rs_fprop,
    self.ub_overlap_ag_dgrad,
    self.ub_overlap_rs_dgrad,
    self.ub_bulk_wgrad,
    self.ub_bulk_dgrad,
    self.ub_name,
    self.fsdp_group,
    self,
    skip_fp8_weight_update,
    self.symmetric_ar_type,
    debug,
)
out = fwd_fn(
    *autograd_ctx,
    inp,
    self.layer_norm_weight,
    self.layer_norm_bias,
    weight_tensor,
    bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
    non_tensor_args,
)

self.end_forward()
```
Same NVTX range stack leak vulnerability as in linear.py. If an exception occurs between prepare_forward() (line 1519) and end_forward() (line 1601), the NVTX range pushed in prepare_forward() will never be popped, corrupting the profiling stack.
The original context manager pattern was exception-safe. Consider using try-finally to ensure end_forward() is always called:

```python
inp = self.prepare_forward(inp, allow_non_contiguous=False)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()
```
```python
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
    self._get_quantizers(fp8_output, is_grad_enabled)
    if not debug
    else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
    if self.no_debug_features_active(quantizers):
        debug = False
        quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
    fc1_input_quantizer,
    fc1_weight_quantizer,
    fc1_output_quantizer,
    fc1_grad_input_quantizer,
    fc1_grad_weight_quantizer,
    fc1_grad_output_quantizer,
    fc2_input_quantizer,
    fc2_weight_quantizer,
    fc2_output_quantizer,
    fc2_grad_input_quantizer,
    fc2_grad_weight_quantizer,
    fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
    if isinstance(fc1_weight, Float8Tensor):
        fc1_weight = fc1_weight.dequantize()
    if isinstance(fc2_weight, Float8Tensor):
        fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
    self.bias_gelu_nvfusion = False

if is_grad_enabled:
    fwd_fn = _LayerNormMLP.apply
    autograd_ctx = []
else:
    fwd_fn = _LayerNormMLP.forward
    autograd_ctx = [None]

non_tensor_args = (
    self.eps,
    is_first_microbatch,
    self.fp8,
    self.fp8_calibration,
    self.wgrad_store,
    self.fuse_wgrad_accumulation,
    fc1_input_quantizer,
    fc1_weight_quantizer,
    fc1_output_quantizer,
    fc1_grad_input_quantizer,
    fc1_grad_weight_quantizer,
    fc1_grad_output_quantizer,
    fc2_input_quantizer,
    fc2_weight_quantizer,
    fc2_output_quantizer,
    fc2_grad_input_quantizer,
    fc2_grad_weight_quantizer,
    fc2_grad_output_quantizer,
    is_cpu_offload_enabled(),
    self.tp_group,
    self.tp_size,
    self.sequence_parallel,
    self.tp_size > 1,
    self.activation_dtype,
    self.return_layernorm_output,
    self.return_layernorm_output_gathered,
    self.bias_gelu_nvfusion and not self.fp8 and not debug,
    self.set_parallel_mode,
    is_grad_enabled,
    self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
    self.bwd_ln_sm_margin,
    self.zero_centered_gamma,
    self.activation,
    self.activation_params,
    self.normalization,
    self.ub_overlap_ag,
    self.ub_overlap_rs,
    self.ub_overlap_rs_dgrad,
    self.ub_bulk_dgrad,
    self.ub_bulk_wgrad,
    self.gemm_gelu_fusion and not debug,
    self.fsdp_group,
    self,
    skip_fp8_weight_update,
    self.symmetric_ar_type,
    self.checkpoint,
    debug,
)
out = fwd_fn(
    *autograd_ctx,
    inp,
    self.layer_norm_weight,
    self.layer_norm_bias,
    fc1_weight,
    fc1_bias,
    fc2_weight,
    fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
    non_tensor_args,
)
```
Same NVTX range stack leak vulnerability as in linear.py and layernorm_linear.py. If an exception occurs between prepare_forward() (line 2057) and end_forward() (line 2159), the NVTX range will leak.
Consider using try-finally to ensure cleanup:

```python
inp = self.prepare_forward(inp, num_gemms=2)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()
```
```python
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
    if self.no_debug_features_active(list(chain(*quantizers))):
        debug = False
        quantizers = self._get_quantizers()

    if isinstance(weight_tensors, QuantizedTensorStorage):
        raise RuntimeError("FP8 weights are not supported in debug mode.")

(
    input_quantizers,
    weight_quantizers,
    output_quantizers,
    grad_input_quantizers,
    grad_weight_quantizers,
    grad_output_quantizers,
) = quantizers

if is_grad_enabled:
    linear_fn = _GroupedLinear.apply
    autograd_ctx = []
else:
    linear_fn = _GroupedLinear.forward
    autograd_ctx = [None]

non_tensor_args = (
    m_splits,
    self.apply_bias,
    is_first_microbatch,
    self.fp8,
    self.fp8_calibration,
    self.wgrad_store,
    input_quantizers,
    weight_quantizers,
    output_quantizers,
    grad_input_quantizers,
    grad_weight_quantizers,
    grad_output_quantizers,
    self.fuse_wgrad_accumulation,
    is_cpu_offload_enabled(),
    self.sequence_parallel,
    self.activation_dtype,
    is_grad_enabled,
    self,
    None,  # skip_fp8_weight_update
    self.save_original_input,
    debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
```
Same NVTX range stack leak vulnerability. If an exception occurs between prepare_forward() (line 793) and end_forward() (line 847), the NVTX range will leak.
Consider using try-finally to ensure cleanup:
```python
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
try:
    # ... processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()
```
Greptile Overview

Greptile Summary

This PR attempts to optimize CPU overhead by:
Critical Issues Found

1. NVTX Range Imbalance on Exceptions (HIGH SEVERITY): The refactoring from context managers to manual prepare_forward()/end_forward() calls can leave NVTX ranges unpopped when an exception is raised.
2. setattr Doesn't Actually Optimize (HIGH SEVERITY): The new __setattr__ override still falls through to the slow torch.nn.Module path.
3. Multiple RuntimeWarning Violations (CRITICAL SEVERITY): Six locations in base.py use direct attribute assignment after initialization (lines 965, 966, 1558, 1559, 1565, 1581, 1608). Since pytest.ini now treats RuntimeWarnings as errors, all tests will fail.

Positive Aspects
Recommendation

This PR cannot be merged in its current state due to the RuntimeWarning violations that will cause all tests to fail. The NVTX exception safety issue is also critical for production use.

Confidence Score: 0/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Module as Linear/LayerNorm Module
    participant Base as TransformerEngineBaseModule
    participant NVTX as NVTX Range Stack
    participant FP8 as FP8GlobalStateManager
    User->>Module: forward(inp)
    Module->>Base: prepare_forward(inp)
    Base->>Base: init_fp8_metadata()
    Base->>Base: get_amp_dtype()
    Base->>FP8: Copy FP8 metadata (if recompute)
    Base->>NVTX: nvtx_range_push("Module forward")
    Note over NVTX: Range pushed - needs cleanup!
    Base-->>Module: returns processed inp
    alt Exception occurs during forward
        Module->>Module: Compute quantizers/weights
        Module--xModule: Exception raised!
        Note over NVTX: ⚠️ BUG: nvtx_range_pop() never called!
        Note over NVTX: Stack becomes imbalanced
    else Normal execution
        Module->>Module: Compute quantizers/weights
        Module->>Module: Call linear_fn()
        Module->>Base: end_forward()
        Base->>FP8: Restore FP8 metadata (if recompute)
        Base->>NVTX: nvtx_range_pop()
        Note over NVTX: Range properly cleaned up
        Base-->>Module: return
        Module-->>User: return output
    end
```
7 files reviewed, 7 comments
```python
def end_forward(self):
    """
    Required to be called at the end of the forward function to properly handle
    DelayedScaling metadata handling and the NVTX ranges.
    """
    delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
    if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
        FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
    nvtx_range_pop()
```
The new non-context-manager approach to prepare_forward() and end_forward() creates a critical bug: if an exception is raised between the prepare_forward() call and the end_forward() call in any forward method, the NVTX range will never be popped, leading to an imbalanced NVTX stack.
The original context manager approach with `with self.prepare_forward(...) as inp:` guaranteed that cleanup would happen even on exceptions. The new approach breaks this guarantee.
Impact: All forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear are affected. Any exception during forward (e.g., shape mismatch, CUDA OOM, assertion failure) will leave the NVTX range stack corrupted.
Recommendation: Either:

- Revert to the context manager approach (accepting the performance cost), or
- Wrap the forward logic in try/finally to ensure end_forward() is always called:

```python
inp = self.prepare_forward(inp, ...)
try:
    # forward logic here
    out = linear_fn(...)
finally:
    self.end_forward()
```
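For a sense of the "performance cost" in option 1, a quick micro-benchmark of a no-op contextlib-based context manager, in the spirit of the earlier benchmarking script (my own sketch, not TE code):

```python
import contextlib
import time


@contextlib.contextmanager
def _noop():
    # Stand-in for a context-manager-based prepare_forward/end_forward pair.
    yield


iters = 1_000_000

start = time.perf_counter()
for _ in range(iters):
    with _noop():
        pass
elapsed = time.perf_counter() - start
print(f"contextmanager enter/exit: {elapsed * 1e9 / iters:.0f} ns/iter")
```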
```python
def __setattr__(self, name: str, value: Any) -> None:
    if name in TransformerEngineBaseModule._fast_setattr_names:
        # torch.nn.Module has a custom __setattr__ that handles
        # modules, parameters, and buffers. This is unnecessary
        # overhead when setting plain attrs.
        self.__dict__[name] = value
    else:
        # Default case
        super().__setattr__(name, value)
    if "_initialized" in self.__dict__ and self._initialized:
        warnings.warn(
            """The default implementation of torch.nn.Module introduces significant CPU overhead
            when setting attributes and is therefore not recommended. Please use the explicit
            calls (fast_setattr for setting regular values and module_setattr for setting
            parameters, children modules and buffers).""",
            RuntimeWarning,
        )
    super().__setattr__(name, value)
```
The __setattr__ override doesn't actually optimize anything - it still calls super().__setattr__(name, value) after emitting the warning. This means every attribute assignment still goes through the slow PyTorch Module path, defeating the purpose of the optimization.
The warning is also emitted on EVERY attribute assignment after initialization, which would be extremely noisy and unhelpful for debugging.
Additionally, there are multiple places in the codebase that still use direct attribute assignment after initialization:

- Line 965: `self.tp_group = tp_group` in `set_tensor_parallel_group()`
- Line 966: `self.tp_group_initialized = True`
- Lines 1558, 1559, 1565, 1581, 1608: various debug-related attribute assignments

Since pytest.ini now treats RuntimeWarnings as errors, all of these will cause test failures.

Recommendation: Either:

- Remove the `__setattr__` override entirely and rely on developers remembering to use `fast_setattr`, or
- Update all the remaining direct assignments to use `fast_setattr` or `module_setattr`
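For the first two items, a rewrite along the lines of option 2 could look like this (a toy stand-in, not TE code; the real `set_tensor_parallel_group` signature and `fast_setattr` implementation may differ):

```python
from typing import Any, Optional

import torch


class _ToyModule(torch.nn.Module):
    """Stand-in module showing the rewritten attribute assignments."""

    def fast_setattr(self, name: str, value: Any) -> None:
        # Plain values only: bypass nn.Module.__setattr__ bookkeeping.
        self.__dict__[name] = value

    def set_tensor_parallel_group(self, tp_group: Optional[Any]) -> None:
        # Was: self.tp_group = tp_group and self.tp_group_initialized = True
        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)


m = _ToyModule()
m.set_tensor_parallel_group(None)
assert m.tp_group is None and m.tp_group_initialized
```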
Additional Comments (5)
Description
This PR includes a few performance optimizations targeting the CPU overhead. The code, perf numbers etc. are WIP. The code gets kind of ugly though :-(.
For the prepare_forward changes I did not touch attention (@cyanguwa FYI) since it has multiple exit points from the forward and I was worried that I would miss something there - it would be great if we could refactor that part first to have a single return statement instead.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: