81 changes: 80 additions & 1 deletion examples/llm_sparsity/attention_sparsity/README.md
@@ -1,6 +1,9 @@
# Attention Sparsity for HuggingFace Models

In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.
In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Two methods are supported:

- **Skip-Softmax**: Threshold-based skipping of near-zero attention scores during softmax (requires `attn_implementation="eager"`)
- **Sparse24 Triton**: Fine-grained 2:4 sparsity on attention scores via a fused Triton kernel with autograd support (uses `attn_implementation="modelopt_triton"`)

## Getting Started

@@ -159,6 +162,82 @@ custom_config = {
model = mtsa.sparsify(model, config=custom_config)
```

## Fine-grained 2:4 Sparse Attention

In addition to skip-softmax, Model Optimizer supports **fine-grained 2:4 sparsity** on attention scores via a fused Triton kernel. For every 4 attention scores along the key dimension, the kernel keeps only the top 2 and zeros out the rest — achieving 50% fixed sparsity with no calibration needed.
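The top-2-of-4 selection can be sketched in plain PyTorch. This is illustrative only: the actual implementation fuses the selection into the Triton attention kernel, and `sparse24_mask` is a hypothetical helper name.

```python
import torch


def sparse24_mask(scores: torch.Tensor) -> torch.Tensor:
    """Keep the top 2 of every group of 4 scores along the last (key) dim.

    Returns a boolean mask with exactly 50% of entries set to True.
    """
    *lead, n = scores.shape
    assert n % 4 == 0, "key dimension must be a multiple of 4"
    groups = scores.reshape(*lead, n // 4, 4)
    # Indices of the top-2 values within each group of 4.
    top2 = groups.topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, top2, True)
    return mask.reshape(*lead, n)


scores = torch.tensor([[0.9, 0.1, 0.5, 0.2, 0.3, 0.8, 0.7, 0.0]])
print(sparse24_mask(scores).float().mean().item())  # 0.5 -> fixed 50% sparsity
```

Because the selection is purely local (top 2 within each group of 4), no calibration data is needed, which is what makes the fixed 50% sparsity ratio possible.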

### Quick Example

```python
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SPARSE24_TRITON

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
)

model = mtsa.sparsify(model, config=SPARSE24_TRITON)
```

> [!NOTE]
> Unlike skip-softmax, sparse24 does **not** require `attn_implementation="eager"`. The `mtsa.sparsify` call automatically registers the Triton kernel as `attn_implementation="modelopt_triton"`.

### Running via Command Line

```bash
python hf_sa.py \
    --pyt_ckpt_path meta-llama/Llama-3.1-8B \
    --sparse_attn sparse24_triton \
    --backend triton
```

> **Copilot AI (on lines +191 to +192):** The README instructs running hf_sa.py with `--backend triton`, but the script currently does not use `args.backend` at all (backend is implicitly determined by `--sparse_attn` via the selected config). This makes the documented CLI invocation misleading unless the script is updated to honor `--backend`. Suggested change: drop the `--backend triton` line and end the invocation at `--sparse_attn sparse24_triton`.

### Key Differences from Skip-Softmax

| | Skip-Softmax | Sparse24 Triton |
|---|---|---|
| Method | Threshold-based softmax skipping | 2:4 structured sparsity on attention scores |
| Attention backend | `eager` (patches `F.softmax`) | `modelopt_triton` (fused Triton kernel) |
| Calibration | Optional (RULER-based) | Not needed (fixed top-2-of-4 selection) |
| Sparsity ratio | Variable (depends on threshold) | Fixed 50% |
| Diagonal preservation | N/A | Yes (tiles near the causal diagonal are kept dense) |
| Training support | No | Yes (autograd-compatible forward/backward) |
| Decode support | Yes | Yes (same kernel, `is_causal=False`) |

### Training with Sparse24 Attention

The Triton kernel supports autograd. When `requires_grad=True`, the HF integration automatically uses the backward-capable path:

```python
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16)
model = mtsa.sparsify(model, config=SPARSE24_TRITON)
model.train()

# Gradients flow through the sparse attention
output = model(input_ids=ids, labels=labels)
output.loss.backward() # dQ, dK, dV computed via Triton backward kernels
```

### Custom Sparse24 Configuration

```python
custom_config = {
"sparse_cfg": {
"*attn*": {
"method": "sparse24_triton",
"backend": "triton",
"skip_diagonal_blocks": True, # Keep diagonal tiles dense (recommended)
"enable": True,
},
"default": {"enable": False},
},
}

model = mtsa.sparsify(model, config=custom_config)
```

Set `skip_diagonal_blocks: False` to apply 2:4 sparsity to all tiles including the diagonal (more aggressive but may hurt quality for local attention patterns).
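The tile-level policy behind `skip_diagonal_blocks` can be illustrated with a toy mask. The function name and the flat tile grid below are assumptions for illustration; the real kernel applies this decision per Triton tile inside the fused attention computation.

```python
import torch


def diagonal_dense_tiles(n_q_tiles: int, n_k_tiles: int) -> torch.Tensor:
    """Hypothetical sketch of skip_diagonal_blocks=True.

    True marks tiles kept fully dense: the tiles on the causal diagonal,
    where query positions attend to their own key block. All other tiles
    get 2:4 sparsity applied.
    """
    q = torch.arange(n_q_tiles).unsqueeze(1)
    k = torch.arange(n_k_tiles).unsqueeze(0)
    return q == k


# On a 4x4 tile grid, only the 4 diagonal tiles remain dense.
print(diagonal_dense_tiles(4, 4).int().sum().item())  # 4
```

Local attention patterns concentrate mass near the diagonal, which is why keeping those tiles dense (the default) tends to preserve quality.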

## References

- [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/)
16 changes: 10 additions & 6 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
> **Collaborator:** Do we need to update anything in example readme or changelog?

@@ -31,6 +31,7 @@
from modelopt.torch.sparsity.attention_sparsity.config import (
    SKIP_SOFTMAX_CALIB,
    SKIP_SOFTMAX_DEFAULT,
    SPARSE24_TRITON,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

@@ -43,6 +44,7 @@
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax": SKIP_SOFTMAX_DEFAULT,
    "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
    "sparse24_triton": SPARSE24_TRITON,
}


@@ -144,12 +146,14 @@ def main(args):

    print(f"Loading model: {args.pyt_ckpt_path}")

-    # Load model and tokenizer
-    # Note: attn_implementation="eager" is required for calibration to work properly
-    # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
+    # Select attn_implementation based on sparse method:
+    # - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa)
+    # - sparse24_triton requires "modelopt_triton" (fused Triton kernel)
+    # No need to specify attn_implementation here — mtsa.sparsify() handles it
+    # automatically based on the sparse config (sets "modelopt_triton" for triton
+    # backend, keeps "eager" for pytorch backend).
    model = AutoModelForCausalLM.from_pretrained(
        args.pyt_ckpt_path,
-        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )
> **Contributor review (on lines +149 to 158), ⚠️ Potential issue | 🟡 Minor:** Before/after comparison uses different attention backends for flash_skip_softmax. Before `sparsify()` the model runs with whatever `attn_implementation` was selected at load time (likely `"sdpa"`); after `sparsify()`, `validate_eager_attention` forces `"eager"`. Any output difference now conflates sparsity effects with the SDPA → eager backend switch. For the sparse24_triton path this is less of a concern, but the skip_softmax path should still load with a consistent backend for a meaningful comparison. Consider documenting this limitation in the comment block at lines 149-154, or conditionally setting `attn_implementation="eager"` when the config uses a pytorch backend.

    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
@@ -246,8 +250,8 @@ def main(args):
        "--backend",
        type=str,
        default="pytorch",
-        choices=["pytorch"],
-        help="Backend for sparse attention (default: pytorch). More backends coming soon.",
+        choices=["pytorch", "triton"],
+        help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.",
    )

# Sequence length arguments
37 changes: 31 additions & 6 deletions modelopt/torch/sparsity/attention_sparsity/config.py
@@ -72,8 +72,8 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
        title="Backend implementation.",
        description=(
            "Backend to use for sparse attention computation. "
-            "Only 'pytorch' is supported, which uses softmax patching with F.softmax. "
-            "Requires model to be loaded with attn_implementation='eager'."
+            "'pytorch' uses softmax patching with F.softmax (requires attn_implementation='eager'). "
+            "'triton' uses the fused Triton kernel (requires attn_implementation='modelopt_triton')."
        ),
    )

@@ -89,10 +89,20 @@
        description=(
            "Whether the model uses causal (autoregressive) attention. "
            "If True, sparsity statistics are calculated over the lower triangle only. "
+            "Set to False for cross-attention models. "
            "Defaults to True for decoder-only models like GPT, LLaMA, etc."
        ),
    )

+    skip_diagonal_blocks: bool = ModeloptField(
+        default=True,
+        title="Skip diagonal blocks.",
+        description=(
+            "When True, keep diagonal tiles dense for 2:4 sparse attention. "
+            "Only used by sparse24_triton method. Defaults to True."
+        ),
+    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v):
@@ -104,11 +114,12 @@ def validate_method(cls, v):
    @field_validator("backend")
    @classmethod
    def validate_backend(cls, v):
-        """Validate backend is pytorch."""
-        if v != "pytorch":
+        """Validate backend is pytorch or triton."""
+        if v not in ("pytorch", "triton"):
            raise ValueError(
-                f"Invalid backend: {v}. Only 'pytorch' backend is supported. "
-                f"Model must be loaded with attn_implementation='eager'."
+                f"Invalid backend: {v}. Supported backends: 'pytorch' (requires "
+                f"attn_implementation='eager'), 'triton' (requires "
+                f"attn_implementation='modelopt_triton')."
            )
        return v

@@ -416,10 +427,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
    },
}

# 2:4 structured sparsity via Triton prefill kernel (prefill-only)

> **Contributor review, ⚠️ Potential issue | 🟡 Minor:** The comment says "prefill-only" but the kernel supports both prefill and decode. The PR description explicitly states the unified Triton kernel supports both prefill (2D kernel) and decode (3D kernel) paths with paged KV cache, so the comment should read something like `# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)`.

SPARSE24_TRITON = {
    "sparse_cfg": {
        "*attn*": {
            "method": "sparse24_triton",
            "backend": "triton",
            "skip_diagonal_blocks": True,
            "enable": True,
        },
        "default": {"enable": False},
    },
}


__all__ = [
    "SKIP_SOFTMAX_CALIB",
    "SKIP_SOFTMAX_DEFAULT",
    "SPARSE24_TRITON",
    "CalibrationConfig",
    "FlashSkipSoftmaxConfig",
    "SparseAttentionAttributeConfig",
34 changes: 34 additions & 0 deletions modelopt/torch/sparsity/attention_sparsity/conversion.py
@@ -32,6 +32,37 @@
from .utils import get_named_sparse_attention_modules, get_sparse_attention_modules


def _register_triton_backend_if_needed(model: nn.Module, config: SparseAttentionConfig) -> None:
    """Register the Triton attention backend and set attn_implementation if needed.

    When the config uses ``backend="triton"``, this function:
    1. Registers the Triton kernel with HF's ``ALL_ATTENTION_FUNCTIONS``.
    2. Sets ``model.config._attn_implementation = "modelopt_triton"`` so the
       model dispatches to the Triton kernel at forward time.

    This is called automatically during ``mtsa.sparsify()`` so users never need
    to manually call ``register_triton_attention()`` or set ``attn_implementation``.
    """
    sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {}
    needs_triton = any(
        isinstance(v, dict) and v.get("backend") == "triton" for v in sparse_cfg.values()
    )
    if not needs_triton:
        return

    from .kernels import register_triton_attention

    if register_triton_attention is not None:
        register_triton_attention()

    # Set attn_implementation on the model so HF dispatches to the Triton kernel.
    # HF's ALL_ATTENTION_FUNCTIONS is checked at forward time, not construction time,
    # so this works even after the model is already loaded.
    model_config = getattr(model, "config", None)
    if model_config is not None:
        model_config._attn_implementation = "modelopt_triton"


def is_attn_sparsified(model: nn.Module) -> bool:
    """Check if a model has sparse attention applied.

@@ -61,6 +92,9 @@ def convert_to_sparse_attention_model(
    # Initialize the true module if necessary
    model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

    # Register Triton attention backend and set attn_implementation if needed
    _register_triton_backend_if_needed(model, config)

    # Apply custom model plugins
    register_custom_model_plugins_on_the_fly(model)

56 changes: 56 additions & 0 deletions modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Triton attention kernels for sparse attention optimization."""

import torch

from modelopt.torch.utils import import_plugin

IS_AVAILABLE = False
context_attention_fwd = None
context_attention = None
register_triton_attention = None
set_sparse24 = None

if torch.cuda.is_available():
    with import_plugin(
        "triton",
        msg_if_missing=(
            "Your device is potentially capable of using the triton attention "
            "kernel. Try to install triton with `pip install triton`."
        ),
    ):
        from .triton_unified_attention import context_attention as _context_attention
        from .triton_unified_attention import context_attention_fwd as _context_attention_fwd

        context_attention_fwd = _context_attention_fwd
        context_attention = _context_attention
        IS_AVAILABLE = True

    with import_plugin("transformers"):
        from .hf_triton_attention import register_triton_attention as _register_triton_attention
        from .hf_triton_attention import set_sparse24 as _set_sparse24

        register_triton_attention = _register_triton_attention
        set_sparse24 = _set_sparse24
        _register_triton_attention()

__all__ = [
    "IS_AVAILABLE",
    "context_attention",
    "context_attention_fwd",
    "register_triton_attention",
    "set_sparse24",
]
Loading