Conversation
Signed-off-by: Kai Xu <kaix@nvidia.com>
📝 Walkthrough

This pull request introduces a new Triton-based sparse attention backend supporting 2:4 structured sparsity for efficient LLM inference. It adds Triton kernels, a sparse attention method, configuration options, HuggingFace integration, and comprehensive GPU tests to enable high-performance sparse prefill and decode paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User/CLI
    participant Config as Config/Conversion
    participant Kernel as Kernel Registration
    participant Model as HF Model
    participant Forward as Forward Pass
    User->>Config: select sparse24_triton backend
    activate Config
    Config->>Config: validate backend="triton"
    Config->>Kernel: _register_triton_backend_if_needed()
    deactivate Config
    activate Kernel
    Kernel->>Kernel: register_triton_attention()
    Kernel->>Model: set attn_implementation="modelopt_triton"
    Kernel->>Model: patch attention interface
    deactivate Kernel
    User->>Model: forward pass (prefill/decode)
    activate Model
    Model->>Forward: dispatch via AttentionInterface
    deactivate Model
    activate Forward
    Forward->>Forward: get_sparse_context() from Sparse24Triton
    Forward->>Forward: apply 2:4 sparsity mask
    Forward->>Forward: unified_attention (2D/3D path)
    Forward-->>Forward: return sparse attention output
    deactivate Forward
    Forward-->>User: sparse attention result
```
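The registration step in the diagram above hinges on HuggingFace's pluggable attention interface. Below is a minimal sketch of what registering such a backend can look like, assuming transformers' `AttentionInterface` registry (available in recent transformers releases); the function body and the `_apply_sparse24` flag handling are illustrative stand-ins, not the PR's actual kernel dispatch.

```python
# Sketch only: registers a custom attention implementation under the name
# "modelopt_triton". The dense SDPA fallback stands in for the PR's fused
# Triton kernel; _apply_sparse24 mirrors the flag set by get_sparse_context().
import torch.nn.functional as F
from transformers import AttentionInterface


def modelopt_triton_attention(module, query, key, value, attention_mask,
                              scaling=None, dropout=0.0, **kwargs):
    # query/key/value arrive as (batch, num_heads, seq_len, head_dim).
    if getattr(module, "_apply_sparse24", False):
        # The real backend would invoke the unified Triton kernel here
        # (2D prefill / 3D decode paths); this sketch just falls through.
        pass
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    # The HF interface expects (batch, seq_len, num_heads, head_dim) plus attention weights.
    return attn_output.transpose(1, 2).contiguous(), None


AttentionInterface.register("modelopt_triton", modelopt_triton_attention)
```

Models loaded with `attn_implementation="modelopt_triton"` would then dispatch through this entry point, which is the path the forward-pass half of the diagram exercises.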
```mermaid
sequenceDiagram
    participant Scores as Attention Scores
    participant Method as Sparse24Triton Method
    participant Mask as Sparsity Mask
    participant Triton as Triton Kernel
    Scores->>Method: calculate_sparsity(scores)
    activate Method
    Method->>Mask: _sparse24_mask_along_last_dim()
    activate Mask
    Mask->>Mask: select top-2 per 4-group
    Mask-->>Method: binary sparsity mask
    deactivate Mask
    Method-->>Scores: stats dict (phase, counts)
    deactivate Method
    Scores->>Method: apply_sparsity(scores, mask)
    activate Method
    Method->>Scores: mask * scores
    Method-->>Scores: sparse scores
    deactivate Method
    Scores->>Triton: context_attention_fwd(q, k, v, mask)
    activate Triton
    Triton->>Triton: detect packed vs. unpacked layout
    alt packed layout
        Triton->>Triton: derive seq_lens, pack tensors
        Triton->>Triton: unified_attention (2D kernel)
        Triton->>Triton: unpack output
    else unpacked layout
        Triton->>Triton: reshape for segment layout
        Triton->>Triton: unified_attention (3D kernel + buffers)
        Triton->>Triton: reshape back
    end
    Triton-->>Scores: output attention
    deactivate Triton
```
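For concreteness, the "select top-2 per 4-group" step above keeps the two largest entries in every contiguous group of four along the last dimension and zeroes the rest. A plain-PyTorch sketch of that masking follows; the helper name echoes `_sparse24_mask_along_last_dim()` from the walkthrough, but the body here is illustrative, not the PR's Triton implementation.

```python
# Illustrative 2:4 masking along the last dimension, in plain PyTorch.
import torch


def sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor:
    """Return a 0/1 mask keeping the 2 largest values in each group of 4."""
    *lead, last = scores.shape
    assert last % 4 == 0, "last dimension must be a multiple of 4 for 2:4 sparsity"
    groups = scores.reshape(*lead, last // 4, 4)
    # Indices of the top-2 magnitudes within each group of 4.
    top2 = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups)
    mask.scatter_(-1, top2, 1.0)
    return mask.reshape(*lead, last)


# Example: applying the mask zeroes 2 of every 4 attention-score entries.
scores = torch.randn(1, 2, 8, 16)           # (batch, heads, q_len, k_len)
mask = sparse24_mask_along_last_dim(scores)
sparse_scores = scores * mask                # the apply_sparsity step in the diagram
assert mask.reshape(-1, 4).sum(-1).eq(2).all()
```

The final assert captures the defining invariant of 2:4 structured sparsity: exactly two of every four entries survive.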
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)
249-255: ⚠️ Potential issue | 🟡 Minor

`--backend` argument is parsed but never used.

`args.backend` is not referenced anywhere in `main()` or the helper functions; the backend is already embedded in each `SPARSE_ATTN_CFG_CHOICES` entry. Either wire it up to override the `backend` key in the resolved config, or remove the argument.

🛡️ Minimal fix to either wire it up or remove dead code

Option A — wire it up to override the config backend:

```diff
 sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
+# Override backend from CLI if explicitly specified
+sparse_config = copy.deepcopy(sparse_config)
+for key, cfg in sparse_config.get("sparse_cfg", {}).items():
+    if isinstance(cfg, dict) and "backend" in cfg:
+        cfg["backend"] = args.backend
 # Override calibration options if provided via CLI
```

Option B — remove the unused argument:

```diff
-parser.add_argument(
-    "--backend",
-    type=str,
-    default="pytorch",
-    choices=["pytorch", "triton"],
-    help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.",
-)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 249 - 255, The --backend arg added via parser.add_argument is parsed into args.backend but never used; either remove that parser entry or propagate args.backend into the sparse-attention config before it is used. To fix, in main() (where args is available and the config is selected from SPARSE_ATTN_CFG_CHOICES and/or built into a resolved_cfg) overwrite the backend key with args.backend (e.g., set resolved_cfg["backend"] = args.backend) so the chosen backend actually takes effect, or remove the parser.add_argument("--backend", ...) line to eliminate the dead argument.
🧹 Nitpick comments (2)
pyproject.toml (1)
84-84: Consider adding `E731` for consistency with sibling Triton entries.

The entries for `modelopt/torch/quantization/triton/*` (line 83) and `examples/deepseek/ds_kernel.py` (line 85) both include `E731` in their suppression rules, while the new entry omits it. While no lambda expressions currently exist in the kernel files, adding `E731` maintains consistency across the pattern established for Triton-based code and future-proofs against potential additions.

🔧 Proposed fix for consistency

```diff
-"modelopt/torch/sparsity/attention_sparsity/kernels/*" = ["N803", "N806"]  # triton kernel style
+"modelopt/torch/sparsity/attention_sparsity/kernels/*" = ["N803", "N806", "E731"]  # triton kernel style
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pyproject.toml` at line 84, Add the missing E731 suppression to the Triton kernel entry "modelopt/torch/sparsity/attention_sparsity/kernels/*" so it matches the sibling Triton entries; update the rule that currently lists ["N803", "N806"] to include "E731" as well to maintain consistency and future-proof against lambda usage in the kernels.

modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
73-83: `get_sparse_context` should be `@abstractmethod`.

`SparseAttentionModule.forward()` calls `get_sparse_context()` unconditionally — every registered method must implement it. All peer methods (`calculate_sparsity`, `apply_sparsity`, `name`) are `@abstractmethod`; this should be too. Without the decorator, a subclass can be instantiated without implementing it, silently deferring the `NotImplementedError` to runtime.

♻️ Proposed fix

```diff
-    def get_sparse_context(self, module: torch.nn.Module):
-        """Return a context manager that activates this method's sparsity during forward.
-
-        Each method subclass implements its own activation mechanism:
-        - Softmax-patching methods replace F.softmax during the forward pass.
-        - Kernel-fused methods set flags on ``module`` that the kernel reads.
-
-        Args:
-            module: The SparseAttentionModule wrapping the attention layer.
-        """
-        raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()")
+    @abstractmethod
+    def get_sparse_context(self, module: torch.nn.Module):
+        """Return a context manager that activates this method's sparsity during forward.
+
+        Each method subclass implements its own activation mechanism:
+        - Softmax-patching methods replace F.softmax during the forward pass.
+        - Kernel-fused methods set flags on ``module`` that the kernel reads.
+
+        Args:
+            module: The SparseAttentionModule wrapping the attention layer.
+        """
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py` around lines 73 - 83, Mark get_sparse_context as an abstract method like the other abstract APIs so subclasses cannot be instantiated without implementing it: add the `@abstractmethod` decorator above the get_sparse_context method in the registry (matching how calculate_sparsity, apply_sparsity, and name are decorated) and ensure abstractmethod is imported if not already; this enforces that any subclass used by SparseAttentionModule.forward() must implement get_sparse_context.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 149-158: The before/after comparison is invalid because the model
is loaded with the default attention backend (e.g., "sdpa") via
AutoModelForCausalLM.from_pretrained but after mtsa.sparsify()
validate_eager_attention forces "eager", so differences mix backend changes with
sparsity effects; fix by either (1) explicitly passing
attn_implementation="eager" into AutoModelForCausalLM.from_pretrained when the
sparse config indicates a PyTorch backend/flash_skip_softmax path (detect via
the sparsity config or args), or (2) add a clear comment in the block around
AutoModelForCausalLM.from_pretrained / mtsa.sparsify() documenting this
limitation and that comparisons for flash_skip_softmax should load with
attn_implementation set to "eager" to ensure a fair baseline.
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Line 429: Update the inaccurate comment string "# 2:4 structured sparsity via
Triton prefill kernel (prefill-only)" to indicate the Triton kernel supports
both prefill (2D) and decode (3D) paths with the paged KV cache; locate the
comment in the attention sparsity config where "# 2:4 structured sparsity via
Triton prefill kernel (prefill-only)" appears and change it to something like "#
2:4 structured sparsity via unified Triton kernel (supports prefill 2D and
decode 3D with paged KV cache)" so it correctly documents the kernel
capabilities.
In `@modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py`:
- Around line 142-150: get_sparse_context currently sets module._apply_sparse24
and module._skip_diagonal_blocks but only resets _apply_sparse24 on exit;
preserve and restore the original _skip_diagonal_blocks value to avoid mutating
module state. Fix get_sparse_context by reading the original value (use getattr
to handle missing attribute), set module._skip_diagonal_blocks =
self.skip_diagonal_blocks on entry, and in the finally block restore the
original value (use setattr or delattr if the attribute did not exist
originally). Ensure you still clear _apply_sparse24 as before and reference
get_sparse_context, module._apply_sparse24, module._skip_diagonal_blocks, and
self.skip_diagonal_blocks when implementing the change.
---
Outside diff comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 249-255: The --backend arg added via parser.add_argument is parsed
into args.backend but never used; either remove that parser entry or propagate
args.backend into the sparse-attention config before it is used. To fix, in
main() (where args is available and the config is selected from
SPARSE_ATTN_CFG_CHOICES and/or built into a resolved_cfg) overwrite the backend
key with args.backend (e.g., set resolved_cfg["backend"] = args.backend) so the
chosen backend actually takes effect, or remove the
parser.add_argument("--backend", ...) line to eliminate the dead argument.
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py`:
- Around line 73-83: Mark get_sparse_context as an abstract method like the
other abstract APIs so subclasses cannot be instantiated without implementing
it: add the `@abstractmethod` decorator above the get_sparse_context method in the
registry (matching how calculate_sparsity, apply_sparsity, and name are
decorated) and ensure abstractmethod is imported if not already; this enforces
that any subclass used by SparseAttentionModule.forward() must implement
get_sparse_context.
In `@pyproject.toml`:
- Line 84: Add the missing E731 suppression to the Triton kernel entry
"modelopt/torch/sparsity/attention_sparsity/kernels/*" so it matches the sibling
Triton entries; update the rule that currently lists ["N803", "N806"] to include
"E731" as well to maintain consistency and future-proof against lambda usage in
the kernels.
```python
# Select attn_implementation based on sparse method:
# - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa)
# - sparse24_triton requires "modelopt_triton" (fused Triton kernel)
# No need to specify attn_implementation here — mtsa.sparsify() handles it
# automatically based on the sparse config (sets "modelopt_triton" for triton
# backend, keeps "eager" for pytorch backend).
model = AutoModelForCausalLM.from_pretrained(
    args.pyt_ckpt_path,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)
```
Before/after comparison uses different attention backends for flash_skip_softmax.
Before sparsify() the model runs with whatever attn_implementation was selected at load time (likely "sdpa"); after sparsify() validate_eager_attention forces "eager". Any output difference now conflates sparsity effects with the SDPA → eager backend switch. For the sparse24_triton path this is less of a concern, but the skip_softmax path should still load with a consistent backend for a meaningful comparison.
Consider documenting this limitation in the comment block at lines 149-154, or conditionally set attn_implementation="eager" when the config uses a pytorch backend:
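One possible shape for that conditional load (a sketch only; the `uses_pytorch_backend` helper and the way `sparse_config` is inspected are hypothetical, not code from the PR):

```python
# Hypothetical sketch of option (1): pick the load-time backend so the
# pre-sparsify baseline matches what validate_eager_attention forces later.
# `args` and `sparse_config` are assumed to come from the example script's main().
import torch
from transformers import AutoModelForCausalLM


def uses_pytorch_backend(sparse_config: dict) -> bool:
    # Treat any per-pattern config without an explicit backend as "pytorch".
    return any(
        isinstance(cfg, dict) and cfg.get("backend", "pytorch") == "pytorch"
        for cfg in sparse_config.get("sparse_cfg", {}).values()
    )


attn_impl = "eager" if uses_pytorch_backend(sparse_config) else None
model = AutoModelForCausalLM.from_pretrained(
    args.pyt_ckpt_path,
    attn_implementation=attn_impl,  # "eager" baseline for skip_softmax-style configs
    torch_dtype=torch.bfloat16,
)
```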
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 149 - 158,
The before/after comparison is invalid because the model is loaded with the
default attention backend (e.g., "sdpa") via
AutoModelForCausalLM.from_pretrained but after mtsa.sparsify()
validate_eager_attention forces "eager", so differences mix backend changes with
sparsity effects; fix by either (1) explicitly passing
attn_implementation="eager" into AutoModelForCausalLM.from_pretrained when the
sparse config indicates a PyTorch backend/flash_skip_softmax path (detect via
the sparsity config or args), or (2) add a clear comment in the block around
AutoModelForCausalLM.from_pretrained / mtsa.sparsify() documenting this
limitation and that comparisons for flash_skip_softmax should load with
attn_implementation set to "eager" to ensure a fair baseline.
```python
    },
}


# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
```
Comment says "prefill-only" but the kernel supports both prefill and decode.
The PR description explicitly states the unified Triton kernel supports both prefill (2D kernel) and decode (3D kernel) paths with paged KV cache. The comment at line 429 is inaccurate and should be corrected to avoid misleading users.
📝 Proposed fix

```diff
-# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
+# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion
# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/config.py` at line 429, Update the
inaccurate comment string "# 2:4 structured sparsity via Triton prefill kernel
(prefill-only)" to indicate the Triton kernel supports both prefill (2D) and
decode (3D) paths with the paged KV cache; locate the comment in the attention
sparsity config where "# 2:4 structured sparsity via Triton prefill kernel
(prefill-only)" appears and change it to something like "# 2:4 structured
sparsity via unified Triton kernel (supports prefill 2D and decode 3D with paged
KV cache)" so it correctly documents the kernel capabilities.
```python
@contextlib.contextmanager
def get_sparse_context(self, module: torch.nn.Module):
    """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel."""
    module._apply_sparse24 = True
    module._skip_diagonal_blocks = self.skip_diagonal_blocks
    try:
        yield
    finally:
        module._apply_sparse24 = False
```
_skip_diagonal_blocks is not restored in the finally block.
get_sparse_context sets both _apply_sparse24 and _skip_diagonal_blocks on the module but only resets _apply_sparse24 on exit. If a module had a different _skip_diagonal_blocks value before entering the context, it will be silently overwritten.
Proposed fix

```diff
 @contextlib.contextmanager
 def get_sparse_context(self, module: torch.nn.Module):
     """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel."""
+    prev_sparse24 = getattr(module, "_apply_sparse24", False)
+    prev_skip_diag = getattr(module, "_skip_diagonal_blocks", True)
     module._apply_sparse24 = True
     module._skip_diagonal_blocks = self.skip_diagonal_blocks
     try:
         yield
     finally:
-        module._apply_sparse24 = False
+        module._apply_sparse24 = prev_sparse24
+        module._skip_diagonal_blocks = prev_skip_diag
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
@contextlib.contextmanager
def get_sparse_context(self, module: torch.nn.Module):
    """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel."""
    prev_sparse24 = getattr(module, "_apply_sparse24", False)
    prev_skip_diag = getattr(module, "_skip_diagonal_blocks", True)
    module._apply_sparse24 = True
    module._skip_diagonal_blocks = self.skip_diagonal_blocks
    try:
        yield
    finally:
        module._apply_sparse24 = prev_sparse24
        module._skip_diagonal_blocks = prev_skip_diag
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py` around
lines 142 - 150, get_sparse_context currently sets module._apply_sparse24 and
module._skip_diagonal_blocks but only resets _apply_sparse24 on exit; preserve
and restore the original _skip_diagonal_blocks value to avoid mutating module
state. Fix get_sparse_context by reading the original value (use getattr to
handle missing attribute), set module._skip_diagonal_blocks =
self.skip_diagonal_blocks on entry, and in the finally block restore the
original value (use setattr or delattr if the attribute did not exist
originally). Ensure you still clear _apply_sparse24 as before and reference
get_sparse_context, module._apply_sparse24, module._skip_diagonal_blocks, and
self.skip_diagonal_blocks when implementing the change.
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #916      +/-   ##
==========================================
- Coverage   73.11%   72.98%   -0.13%
==========================================
  Files         205      206       +1
  Lines       22281    22347      +66
==========================================
+ Hits        16291    16311      +20
- Misses       5990     6036      +46
==========================================
```

☔ View full report in Codecov by Sentry.
What does this PR do?
Type of change: New feature
Overview:
- Adds a unified Triton attention kernel (`triton_unified_attention.py`) for both prefill (2D kernel) and decode (3D kernel), with paged KV cache support.
- Adds an `attn_implementation="modelopt_triton"` that is automatically set by `mtsa.sparsify()`.

Usage
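A sketch of what that usage can look like end to end (the `mtsa` import path follows the repository layout and the `mtsa.sparsify()` call mentioned above; the config name `SPARSE24_TRITON_CFG`, the sparsify signature, and the checkpoint name are assumptions for illustration):

```python
# Illustrative usage flow; config object name and sparsify() signature are assumed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.sparsity.attention_sparsity as mtsa

ckpt = "meta-llama/Llama-3.1-8B-Instruct"  # any HF causal LM checkpoint
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16, device_map="cuda")

# sparsify() wraps the attention layers with the 2:4 method and, for the triton
# backend, switches the model to attn_implementation="modelopt_triton".
model = mtsa.sparsify(model, config=mtsa.SPARSE24_TRITON_CFG)  # hypothetical config name

tokenizer = AutoTokenizer.from_pretrained(ckpt)
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=32)  # exercises prefill + decode paths
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```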
Testing
tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Tests