From 30fa3dad3aeaafc39decd930b2f68ff447e7721f Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 21 Feb 2026 16:15:36 -0800 Subject: [PATCH] Add 2:4 sparse attention kernel Signed-off-by: Kai Xu --- .../llm_sparsity/attention_sparsity/README.md | 81 +- .../llm_sparsity/attention_sparsity/hf_sa.py | 16 +- .../sparsity/attention_sparsity/config.py | 37 +- .../sparsity/attention_sparsity/conversion.py | 34 + .../attention_sparsity/kernels/__init__.py | 56 + .../kernels/hf_triton_attention.py | 374 ++++++ .../kernels/triton_unified_attention.py | 1011 +++++++++++++++++ .../attention_sparsity/methods/__init__.py | 2 +- .../methods/flash_skip_softmax.py | 16 + .../attention_sparsity/methods/registry.py | 12 + .../methods/sparse24_triton.py | 161 +++ .../attention_sparsity/plugins/huggingface.py | 23 +- .../attention_sparsity/sparse_attention.py | 53 +- pyproject.toml | 1 + .../test_triton_unified_attention.py | 862 ++++++++++++++ 15 files changed, 2676 insertions(+), 63 deletions(-) create mode 100644 modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md index e9d50ae10..a69187232 100644 --- a/examples/llm_sparsity/attention_sparsity/README.md +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -1,6 +1,9 @@ # Attention Sparsity for HuggingFace Models -In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. +In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Two methods are supported: + +- **Skip-Softmax**: Threshold-based skipping of near-zero attention scores during softmax (requires `attn_implementation="eager"`) +- **Sparse24 Triton**: Fine-grained 2:4 sparsity on attention scores via a fused Triton kernel with autograd support (uses `attn_implementation="modelopt_triton"`) ## Getting Started @@ -159,6 +162,82 @@ custom_config = { model = mtsa.sparsify(model, config=custom_config) ``` +## Fine-grained 2:4 Sparse Attention + +In addition to skip-softmax, Model Optimizer supports **fine-grained 2:4 sparsity** on attention scores via a fused Triton kernel. For every 4 attention scores along the key dimension, the kernel keeps only the top 2 and zeros out the rest — achieving 50% fixed sparsity with no calibration needed. + +### Quick Example + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import SPARSE24_TRITON + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.1-8B", + torch_dtype=torch.bfloat16, +) + +model = mtsa.sparsify(model, config=SPARSE24_TRITON) +``` + +> [!Note] +> Unlike skip-softmax, sparse24 does **not** require `attn_implementation="eager"`. The `mtsa.sparsify` call automatically registers the Triton kernel as `attn_implementation="modelopt_triton"`. + +### Running via Command Line + +```bash +python hf_sa.py \ + --pyt_ckpt_path meta-llama/Llama-3.1-8B \ + --sparse_attn sparse24_triton \ + --backend triton +``` + +### Key Differences from Skip-Softmax + +| | Skip-Softmax | Sparse24 Triton | +|---|---|---| +| Method | Threshold-based softmax skipping | 2:4 structured sparsity on attention scores | +| Attention backend | `eager` (patches `F.softmax`) | `modelopt_triton` (fused Triton kernel) | +| Calibration | Optional (RULER-based) | Not needed (fixed top-2-of-4 selection) | +| Sparsity ratio | Variable (depends on threshold) | Fixed 50% | +| Diagonal preservation | N/A | Yes (tiles near the causal diagonal are kept dense) | +| Training support | No | Yes (autograd-compatible forward/backward) | +| Decode support | Yes | Yes (same kernel, `is_causal=False`) | + +### Training with Sparse24 Attention + +The Triton kernel supports autograd. When `requires_grad=True`, the HF integration automatically uses the backward-capable path: + +```python +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16) +model = mtsa.sparsify(model, config=SPARSE24_TRITON) +model.train() + +# Gradients flow through the sparse attention +output = model(input_ids=ids, labels=labels) +output.loss.backward() # dQ, dK, dV computed via Triton backward kernels +``` + +### Custom Sparse24 Configuration + +```python +custom_config = { + "sparse_cfg": { + "*attn*": { + "method": "sparse24_triton", + "backend": "triton", + "skip_diagonal_blocks": True, # Keep diagonal tiles dense (recommended) + "enable": True, + }, + "default": {"enable": False}, + }, +} + +model = mtsa.sparsify(model, config=custom_config) +``` + +Set `skip_diagonal_blocks: False` to apply 2:4 sparsity to all tiles including the diagonal (more aggressive but may hurt quality for local attention patterns). + ## References - [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/) diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 74c5e9a54..9a8c694fc 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import ( SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT, + SPARSE24_TRITON, ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor @@ -43,6 +44,7 @@ SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, "skip_softmax_calib": SKIP_SOFTMAX_CALIB, + "sparse24_triton": SPARSE24_TRITON, } @@ -144,12 +146,14 @@ def main(args): print(f"Loading model: {args.pyt_ckpt_path}") - # Load model and tokenizer - # Note: attn_implementation="eager" is required for calibration to work properly - # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) + # Select attn_implementation based on sparse method: + # - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa) + # - sparse24_triton requires "modelopt_triton" (fused Triton kernel) + # No need to specify attn_implementation here — mtsa.sparsify() handles it + # automatically based on the sparse config (sets "modelopt_triton" for triton + # backend, keeps "eager" for pytorch backend). model = AutoModelForCausalLM.from_pretrained( args.pyt_ckpt_path, - attn_implementation="eager", torch_dtype=torch.bfloat16, ) tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) @@ -246,8 +250,8 @@ def main(args): "--backend", type=str, default="pytorch", - choices=["pytorch"], - help="Backend for sparse attention (default: pytorch). More backends coming soon.", + choices=["pytorch", "triton"], + help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.", ) # Sequence length arguments diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index d2d3b1078..e178594ca 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -72,8 +72,8 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): title="Backend implementation.", description=( "Backend to use for sparse attention computation. " - "Only 'pytorch' is supported, which uses softmax patching with F.softmax. " - "Requires model to be loaded with attn_implementation='eager'." + "'pytorch' uses softmax patching with F.softmax (requires attn_implementation='eager'). " + "'triton' uses the fused Triton kernel (requires attn_implementation='modelopt_triton')." ), ) @@ -89,10 +89,20 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description=( "Whether the model uses causal (autoregressive) attention. " "If True, sparsity statistics are calculated over the lower triangle only. " + "Set to False for cross-attention models. " "Defaults to True for decoder-only models like GPT, LLaMA, etc." ), ) + skip_diagonal_blocks: bool = ModeloptField( + default=True, + title="Skip diagonal blocks.", + description=( + "When True, keep diagonal tiles dense for 2:4 sparse attention. " + "Only used by sparse24_triton method. Defaults to True." + ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): @@ -104,11 +114,12 @@ def validate_method(cls, v): @field_validator("backend") @classmethod def validate_backend(cls, v): - """Validate backend is pytorch.""" - if v != "pytorch": + """Validate backend is pytorch or triton.""" + if v not in ("pytorch", "triton"): raise ValueError( - f"Invalid backend: {v}. Only 'pytorch' backend is supported. " - f"Model must be loaded with attn_implementation='eager'." + f"Invalid backend: {v}. Supported backends: 'pytorch' (requires " + f"attn_implementation='eager'), 'triton' (requires " + f"attn_implementation='modelopt_triton')." ) return v @@ -416,10 +427,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): }, } +# 2:4 structured sparsity via Triton prefill kernel (prefill-only) +SPARSE24_TRITON = { + "sparse_cfg": { + "*attn*": { + "method": "sparse24_triton", + "backend": "triton", + "skip_diagonal_blocks": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "SPARSE24_TRITON", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 2155a13d0..26fb4e08a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -32,6 +32,37 @@ from .utils import get_named_sparse_attention_modules, get_sparse_attention_modules +def _register_triton_backend_if_needed(model: nn.Module, config: SparseAttentionConfig) -> None: + """Register the Triton attention backend and set attn_implementation if needed. + + When the config uses ``backend="triton"``, this function: + 1. Registers the Triton kernel with HF's ``ALL_ATTENTION_FUNCTIONS``. + 2. Sets ``model.config._attn_implementation = "modelopt_triton"`` so the + model dispatches to the Triton kernel at forward time. + + This is called automatically during ``mtsa.sparsify()`` so users never need + to manually call ``register_triton_attention()`` or set ``attn_implementation``. + """ + sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} + needs_triton = any( + isinstance(v, dict) and v.get("backend") == "triton" for v in sparse_cfg.values() + ) + if not needs_triton: + return + + from .kernels import register_triton_attention + + if register_triton_attention is not None: + register_triton_attention() + + # Set attn_implementation on the model so HF dispatches to the Triton kernel. + # HF's ALL_ATTENTION_FUNCTIONS is checked at forward time, not construction time, + # so this works even after the model is already loaded. + model_config = getattr(model, "config", None) + if model_config is not None: + model_config._attn_implementation = "modelopt_triton" + + def is_attn_sparsified(model: nn.Module) -> bool: """Check if a model has sparse attention applied. @@ -61,6 +92,9 @@ def convert_to_sparse_attention_model( # Initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + # Register Triton attention backend and set attn_implementation if needed + _register_triton_backend_if_needed(model, config) + # Apply custom model plugins register_custom_model_plugins_on_the_fly(model) diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py new file mode 100644 index 000000000..091b08be4 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton attention kernels for sparse attention optimization.""" + +import torch + +from modelopt.torch.utils import import_plugin + +IS_AVAILABLE = False +context_attention_fwd = None +context_attention = None +register_triton_attention = None +set_sparse24 = None + +if torch.cuda.is_available(): + with import_plugin( + "triton", + msg_if_missing=( + "Your device is potentially capable of using the triton attention " + "kernel. Try to install triton with `pip install triton`." + ), + ): + from .triton_unified_attention import context_attention as _context_attention + from .triton_unified_attention import context_attention_fwd as _context_attention_fwd + + context_attention_fwd = _context_attention_fwd + context_attention = _context_attention + IS_AVAILABLE = True + with import_plugin("transformers"): + from .hf_triton_attention import register_triton_attention as _register_triton_attention + from .hf_triton_attention import set_sparse24 as _set_sparse24 + + register_triton_attention = _register_triton_attention + set_sparse24 = _set_sparse24 + _register_triton_attention() + +__all__ = [ + "IS_AVAILABLE", + "context_attention", + "context_attention_fwd", + "register_triton_attention", + "set_sparse24", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py new file mode 100644 index 000000000..00b330f4e --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py @@ -0,0 +1,374 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hugging Face attention backend for the Triton unified attention kernel. + +Registers the Triton kernel as attn_implementation="modelopt_triton" so HF models +use it natively without patching forward. Both prefill and decode use the unified +Triton kernel. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + +from modelopt.torch.sparsity.attention_sparsity.kernels.triton_unified_attention import ( + context_attention, + context_attention_fwd, +) + + +def _attention_mask_supported_for_triton(attention_mask: torch.Tensor) -> bool: + """Return True if mask shape is supported for packing (2D [batch, seq_len]).""" + return attention_mask.dim() == 2 and attention_mask.shape[0] > 0 and attention_mask.shape[1] > 0 + + +def _packed_token_indices( + seq_lens: torch.Tensor, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute vectorized (batch_idx, token_idx) for packing/unpacking variable-length sequences. + + Assumes valid tokens occupy positions ``0..seq_lens[b]-1`` in each batch + element (right-padded layout). This matches the HF convention where padding + tokens are appended after the valid content during prefill. + + Args: + seq_lens: [batch] number of valid tokens per sequence. + device: Target device. + + Returns: + (batch_indices, token_indices) each of shape [total_valid_tokens]. + """ + total = int(seq_lens.sum().item()) + cumsum = torch.zeros(seq_lens.shape[0] + 1, device=device, dtype=torch.long) + cumsum[1:] = torch.cumsum(seq_lens, dim=0) + flat_idx = torch.arange(total, device=device, dtype=torch.long) + batch_indices = torch.bucketize(flat_idx, cumsum[1:], right=True) + token_indices = flat_idx - cumsum[batch_indices] + return batch_indices, token_indices + + +def _derive_seq_lens_and_pack( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Derive b_seq_len and b_start_loc from 2D mask; pack q,k,v to contiguous [total, heads, dim]. + + attention_mask: [batch, seq_len], 1 = valid, 0 = pad. Assumes valid tokens are + at positions 0..n-1 (right-padded layout). The count of valid tokens per row + determines the packing lengths. + Returns: (q_packed, k_packed, v_packed, b_start_loc, b_seq_len, max_input_len). + """ + batch = query.shape[0] + device = query.device + # Valid length per batch: number of ones (or non-zero) in the mask per row + if attention_mask.dtype == torch.bool: + seq_lens = attention_mask.sum(dim=1).long() + else: + seq_lens = (attention_mask != 0).sum(dim=1).long() + seq_lens = seq_lens.to(device) + b_start_loc = torch.zeros(batch + 1, device=device, dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(seq_lens, dim=0) + b_start_loc = b_start_loc[:batch] + b_seq_len = seq_lens.to(torch.int32) + max_input_len = int(seq_lens.max().item()) + + # Vectorized packing: query [batch, heads, seq, dim] -> [total, heads, dim] + batch_indices, token_indices = _packed_token_indices(seq_lens, device) + q_packed = query[batch_indices, :, token_indices, :].contiguous() + k_packed = key[batch_indices, :, token_indices, :].contiguous() + v_packed = value[batch_indices, :, token_indices, :].contiguous() + return q_packed, k_packed, v_packed, b_start_loc, b_seq_len, max_input_len + + +def _unpack_attn_output( + o_packed: torch.Tensor, + batch: int, + num_heads: int, + head_dim: int, + seq_len: int, + b_seq_len: torch.Tensor, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + """Scatter packed output [total_tokens, num_heads, head_dim] to [batch, seq_len, num_heads, head_dim].""" + attn_output = torch.zeros(batch, seq_len, num_heads, head_dim, device=device, dtype=dtype) + total = int(b_seq_len.sum().item()) + if total == 0: + return attn_output + batch_indices, token_indices = _packed_token_indices(b_seq_len.long(), device) + attn_output[batch_indices, token_indices] = o_packed + return attn_output + + +def _decode_attention( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, +) -> torch.Tensor: + """Decode attention via context_attention_fwd (one query token per sequence). + + Reshapes HF-format K/V from [batch, kv_heads, seq_k, dim] to flat packed + [total_kv_tokens, kv_heads, dim] and calls context_attention_fwd with + is_causal=False so the single query token attends to all K/V positions. + + Args: + module: The attention module (unused; kept for API compatibility). + query: [batch, num_heads, 1, head_dim]. + key: [batch, num_kv_heads, seq_k, head_dim]. + value: [batch, num_kv_heads, seq_k, head_dim]. + attention_mask: Optional 2D [batch, seq_k] mask; 1=valid, 0=pad. + scaling: Softmax scale. + + Returns: + attn_output: [batch, 1, num_heads, head_dim]. + """ + batch = query.shape[0] + num_kv_heads = key.shape[1] + seq_k = key.shape[2] + head_dim = query.shape[3] + device = query.device + + # Q: [batch, heads, 1, dim] -> [batch, heads, dim] (flat: 1 token per batch) + q_flat = query.squeeze(2).contiguous() + + # K/V: [batch, kv_heads, seq_k, dim] -> [batch * seq_k, kv_heads, dim] + k_flat = key.permute(0, 2, 1, 3).reshape(batch * seq_k, num_kv_heads, head_dim).contiguous() + v_flat = value.permute(0, 2, 1, 3).reshape(batch * seq_k, num_kv_heads, head_dim).contiguous() + + # Q metadata: each batch element has 1 query token + b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device=device, dtype=torch.int32) + + # K/V metadata: each batch element has seq_k tokens (or fewer if masked) + if attention_mask is not None and _attention_mask_supported_for_triton(attention_mask): + if attention_mask.dtype == torch.bool: + b_seq_len_k = attention_mask.sum(dim=1).to(torch.int32).to(device) + else: + b_seq_len_k = (attention_mask != 0).sum(dim=1).to(torch.int32).to(device) + else: + b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32) + + b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k + + o_flat = torch.empty_like(q_flat) + context_attention_fwd( + q_flat, + k_flat, + v_flat, + o_flat, + b_start_loc=b_start_loc_q, + b_seq_len=b_seq_len_q, + max_input_len=1, + is_causal=False, + softmax_scale=scaling, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=int(b_seq_len_k.max().item()), + ) + + # [batch, heads, dim] -> [batch, 1, heads, dim] + return o_flat.unsqueeze(1) + + +def triton_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +) -> tuple[torch.Tensor, None]: + """Attention forward compatible with HF AttentionInterface. + + Uses the unified Triton kernel for both prefill (seq_len > 1) and decode + (seq_len == 1). Same signature as eager_attention_forward. + + Args: + module: The attention module (LlamaAttention etc.). + query: [batch, num_heads, seq_len, head_dim]. + key: [batch, num_kv_heads, seq_k, head_dim]. + value: [batch, num_kv_heads, seq_k, head_dim]. + attention_mask: Optional; kernel handles causal internally. + 2D [batch, seq_k] masks are used to derive per-sequence lengths. + Unsupported formats raise an error. + scaling: Softmax scale (e.g. 1/sqrt(head_dim)). + dropout: Ignored (kernel has no dropout); use 0 for eval. + **kwargs: May contain apply_sparse24, skip_diagonal_blocks for 2:4 sparse attention. + + Returns: + (attn_output, None) with attn_output [batch, seq_len, num_heads, head_dim]. + """ + batch, num_heads, seq_len, head_dim = query.shape + seq_k = key.shape[2] + is_cross_attention = seq_len != seq_k + + # Decode: one query token per sequence, full context in K/V + if seq_len <= 1: + attn_output = _decode_attention(module, query, key, value, attention_mask, scaling) + return (attn_output, None) + + device = query.device + num_kv_heads = key.shape[1] + is_causal = not is_cross_attention + apply_sparse24 = kwargs.get("apply_sparse24", getattr(module, "_apply_sparse24", False)) + skip_diagonal_blocks = kwargs.get( + "skip_diagonal_blocks", getattr(module, "_skip_diagonal_blocks", True) + ) + + needs_grad = torch.is_grad_enabled() and ( + query.requires_grad or key.requires_grad or value.requires_grad + ) + + use_packed = attention_mask is not None and _attention_mask_supported_for_triton(attention_mask) + if use_packed: + q_packed, k_packed, v_packed, b_start_loc, b_seq_len, max_input_len = ( + _derive_seq_lens_and_pack(query, key, value, attention_mask) + ) + fwd_kwargs = { + "b_start_loc": b_start_loc, + "b_seq_len": b_seq_len, + "max_input_len": max_input_len, + "is_causal": is_causal, + "softmax_scale": scaling, + "apply_sparse24": apply_sparse24, + "skip_diagonal_blocks": skip_diagonal_blocks, + } + if needs_grad: + o_packed = context_attention(q_packed, k_packed, v_packed, **fwd_kwargs) + else: + o_packed = torch.empty_like(q_packed) + context_attention_fwd(q_packed, k_packed, v_packed, o_packed, **fwd_kwargs) + attn_output = _unpack_attn_output( + o_packed, + batch, + num_heads, + head_dim, + seq_len, + b_seq_len, + query.dtype, + device, + ) + return (attn_output, None) + if attention_mask is not None: + raise ValueError( + f"Unsupported attention_mask format for modelopt_triton: " + f"dim={attention_mask.dim()}, shape={attention_mask.shape}. " + f"Only 2D [batch, seq_len] masks are supported." + ) + + q = query.permute(0, 2, 1, 3).reshape(-1, num_heads, head_dim).contiguous() + k = key.permute(0, 2, 1, 3).reshape(-1, num_kv_heads, head_dim).contiguous() + v = value.permute(0, 2, 1, 3).reshape(-1, num_kv_heads, head_dim).contiguous() + b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) * seq_len + b_seq_len_q = torch.full((batch,), seq_len, device=device, dtype=torch.int32) + + if is_cross_attention: + b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k + b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32) + else: + b_start_loc_k = None + b_seq_len_k = None + + fwd_kwargs = { + "b_start_loc": b_start_loc_q, + "b_seq_len": b_seq_len_q, + "max_input_len": seq_len, + "is_causal": is_causal, + "softmax_scale": scaling, + "apply_sparse24": apply_sparse24, + "skip_diagonal_blocks": skip_diagonal_blocks, + "b_start_loc_k": b_start_loc_k, + "b_seq_len_k": b_seq_len_k, + "max_input_len_k": seq_k if is_cross_attention else None, + } + if needs_grad: + o = context_attention(q, k, v, **fwd_kwargs) + else: + o = torch.empty_like(q) + context_attention_fwd(q, k, v, o, **fwd_kwargs) + attn_output = o.view(batch, seq_len, num_heads, head_dim) + return (attn_output, None) + + +def register_triton_attention() -> bool: + """Register the Triton backend with HF AttentionInterface. + + Call after importing this module so that attn_implementation="modelopt_triton" + is available when loading models. + + Returns: + True if registration succeeded. + """ + try: + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + ALL_ATTENTION_FUNCTIONS.register("modelopt_triton", triton_attention_forward) + return True + except Exception: + return False + + +def set_sparse24( + model: nn.Module, + apply_sparse24: bool = True, + skip_diagonal_blocks: bool = True, +) -> None: + """Set 2:4 sparse attention on all attention modules in the model. + + Prefer using ``mtsa.sparsify(model, SPARSE24_TRITON)`` from + ``modelopt.torch.sparsity.attention_sparsity`` for config-driven setup, + pattern-based layer selection, and consistency with other sparse methods. + This helper remains for backward compatibility and one-off scripting. + + The Triton backend reads ``getattr(module, '_apply_sparse24', False)`` and + ``getattr(module, '_skip_diagonal_blocks', True)`` when kwargs don't provide them. + + Limitations: + - **Prefill-only sparsity:** 2:4 sparsity is applied during prefill only; + decode uses the unified kernel without sparsity. + - **Fixed 50% sparsity:** 2:4 keeps top 2 of every 4 attention scores; + no threshold tuning or calibration. + - **Mutually exclusive with flash_skip_softmax:** sparse24 requires + ``attn_implementation="modelopt_triton"``; flash_skip_softmax requires + ``attn_implementation="eager"``. They cannot be combined in one model. + + Args: + model: Hugging Face model (e.g. LlamaForCausalLM). + apply_sparse24: Whether to apply 2:4 sparsity to attention scores. + skip_diagonal_blocks: If True, keep diagonal tiles dense (local attention). + """ + for _, module in model.named_modules(): + # Match only actual attention modules (have o_proj + head_dim), not their children + # like q_proj, k_proj, v_proj, rotary_emb, etc. + if hasattr(module, "o_proj") and hasattr(module, "head_dim"): + setattr(module, "_apply_sparse24", apply_sparse24) + setattr(module, "_skip_diagonal_blocks", skip_diagonal_blocks) + + +__all__ = [ + "register_triton_attention", + "set_sparse24", + "triton_attention_forward", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py new file mode 100644 index 000000000..1e42ae41a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py @@ -0,0 +1,1011 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Prefill kernel adapted from context_flashattention_nopad in SGLang / LightLLM. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton attention kernel for prefill and decode on flat packed tensors. + +Supports variable sequence lengths, causal/non-causal masking, GQA, 2:4 sparse attention, +and autograd-compatible forward/backward. +""" + +import torch +import triton +import triton.language as tl + +LOG2E: float = 1.44269504088896 + + +# --------------------------------------------------------------------------- +# 2:4 structured sparsity helpers +# --------------------------------------------------------------------------- +@triton.jit +def _sparse24_noabs_ops(x0, x1, x2, x3): + """Compute 2:4 sparsity mask: for every 4 values, determine which 2 are largest.""" + (a1, a2, a3, a4, a5, a6) = ( + x0 > x1, + x0 > x2, + x0 > x3, + x1 > x2, + x1 > x3, + x2 > x3, + ) + na1 = a1 == 0 + na2 = a2 == 0 + na3 = a3 == 0 + na4 = a4 == 0 + na5 = a5 == 0 + na6 = a6 == 0 + m0 = a2 & a3 | a1 & a2 | a1 & a3 + m1 = na1 & a5 | a4 & a5 | na1 & a4 + m2 = na2 & na4 | na2 & a6 | na4 & a6 + m3 = na3 & na5 | na3 & na6 | na5 & na6 + return x0, x1, x2, x3, m0, m1, m2, m3 + + +@triton.jit +def _apply_sparse24_to_qk_tile( + qk, + M: tl.constexpr, + N: tl.constexpr, + MASK_VAL: tl.constexpr, +): + """Apply 2:4 sparsity to attention score tile [M, N]: keep top 2 of every 4 along N.""" + reshaped = tl.reshape(qk, (M, N // 4, 4)) + cols = tl.arange(0, 4)[None, None, :] + x0 = tl.sum(tl.where(cols == 0, reshaped, 0.0), axis=2) + x1 = tl.sum(tl.where(cols == 1, reshaped, 0.0), axis=2) + x2 = tl.sum(tl.where(cols == 2, reshaped, 0.0), axis=2) + x3 = tl.sum(tl.where(cols == 3, reshaped, 0.0), axis=2) + _, _, _, _, m0, m1, m2, m3 = _sparse24_noabs_ops(x0, x1, x2, x3) + s0 = tl.where(m0, x0, MASK_VAL) + s1 = tl.where(m1, x1, MASK_VAL) + s2 = tl.where(m2, x2, MASK_VAL) + s3 = tl.where(m3, x3, MASK_VAL) + sparse_reshaped = tl.full((M, N // 4, 4), 0.0, dtype=qk.dtype) + sparse_reshaped = tl.where((cols == 0), tl.expand_dims(s0, 2), sparse_reshaped) + sparse_reshaped = tl.where((cols == 1), tl.expand_dims(s1, 2), sparse_reshaped) + sparse_reshaped = tl.where((cols == 2), tl.expand_dims(s2, 2), sparse_reshaped) + sparse_reshaped = tl.where((cols == 3), tl.expand_dims(s3, 2), sparse_reshaped) + sparse_qk = tl.reshape(sparse_reshaped, (M, N)) + return sparse_qk + + +# --------------------------------------------------------------------------- +# Shared: recompute masked S tile (used by forward inner loop and backward) +# --------------------------------------------------------------------------- +@triton.jit +def _mask_and_sparsify( + qk, + offs_m, + offs_n, + cur_batch_seq_len, + cur_batch_kv_len, + start_n, + start_m, + IS_CAUSAL: tl.constexpr, + APPLY_SPARSE24: tl.constexpr, + SKIP_DIAGONAL_BLOCKS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_FWD: tl.constexpr, +): + """Apply causal mask, padding mask, and optional 2:4 sparsity to a QK tile. + + BLOCK_FWD is the forward kernel's block size, used for diagonal detection so that + backward kernels (which may use smaller tiles) produce the same sparsity pattern. + """ + if IS_CAUSAL: + qk += tl.where( + (start_n + offs_n[None, :] < cur_batch_kv_len) + & (offs_m[:, None] >= (start_n + offs_n[None, :])), + 0, + float("-inf"), + ) + else: + qk += tl.where((start_n + offs_n[None, :]) < cur_batch_kv_len, 0, float("-inf")) + + if APPLY_SPARSE24: + if IS_CAUSAL and SKIP_DIAGONAL_BLOCKS: + # Diagonal detection at BLOCK_FWD granularity (matches forward) + q_pos_min = start_m * BLOCK_M + fwd_q_tile_start = (q_pos_min // BLOCK_FWD) * BLOCK_FWD + fwd_q_tile_end = fwd_q_tile_start + BLOCK_FWD + fwd_k_tile_start = (start_n // BLOCK_FWD) * BLOCK_FWD + fwd_k_tile_end = fwd_k_tile_start + BLOCK_FWD + is_diagonal = (fwd_k_tile_start < fwd_q_tile_end) & (fwd_k_tile_end > fwd_q_tile_start) + if not is_diagonal: + qk = _apply_sparse24_to_qk_tile(qk, BLOCK_M, BLOCK_N, float("-inf")) + else: + qk = _apply_sparse24_to_qk_tile(qk, BLOCK_M, BLOCK_N, float("-inf")) + return qk + + +# --------------------------------------------------------------------------- +# Forward kernel +# --------------------------------------------------------------------------- +@triton.jit +def _fwd_kernel_prefill( + Q, + K, + V, + qk_scale, + B_Start_Loc, + B_Seqlen, + B_Start_Loc_K, + B_Seqlen_K, + Out, + Lse, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + stride_lse_tok, + stride_lse_head, + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + Lk: tl.constexpr, + APPLY_SPARSE24: tl.constexpr, + SKIP_DIAGONAL_BLOCKS: tl.constexpr, + STORE_LSE: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_kv_len = tl.load(B_Seqlen_K + cur_batch) + cur_batch_q_start = tl.load(B_Start_Loc + cur_batch) + cur_batch_kv_start = tl.load(B_Start_Loc_K + cur_batch) + + block_start_loc = BLOCK_M * start_m + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_d = offs_d < Lk + + off_q = ( + (cur_batch_q_start + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] + ) + q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & mask_d[None, :], other=0.0) + + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] + k_ptrs = K + off_k + v_ptrs = V + off_v + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + end_n = ( + cur_batch_kv_len if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_kv_len) + ) + + for start_n in range(0, block_mask * end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_kv_start + start_n) * stride_kbs, + mask=((start_n + offs_n[None, :]) < cur_batch_kv_len) & mask_d[:, None], + other=0.0, + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= qk_scale + qk = _mask_and_sparsify( + qk, + offs_m, + offs_n, + cur_batch_seq_len, + cur_batch_kv_len, + start_n, + start_m, + IS_CAUSAL, + APPLY_SPARSE24, + SKIP_DIAGONAL_BLOCKS, + BLOCK_M, + BLOCK_N, + BLOCK_M, # BLOCK_FWD = BLOCK_M in forward + ) + + # deferred-normalization online softmax (exp2) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + v = tl.load( + v_ptrs + (cur_batch_kv_start + start_n) * stride_vbs, + mask=((start_n + offs_n[:, None]) < cur_batch_kv_len) & mask_d[None, :], + other=0.0, + ) + acc = tl.dot(p.to(v.dtype), v, acc) + m_i = m_ij + + acc = acc / l_i[:, None] + + if STORE_LSE: + lse_i = m_i + tl.math.log2(l_i) + lse_i = tl.where(l_i == 0.0, float("-inf"), lse_i) + off_lse = (cur_batch_q_start + offs_m) * stride_lse_tok + cur_head * stride_lse_head + tl.store(Lse + off_lse, lse_i, mask=offs_m < cur_batch_seq_len) + + off_o = ( + (cur_batch_q_start + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] + ) + tl.store(Out + off_o, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & mask_d[None, :]) + + +# --------------------------------------------------------------------------- +# Backward kernels +# --------------------------------------------------------------------------- +@triton.jit +def _bwd_preprocess( + Out, + dO, + Delta, + stride_obs, + stride_oh, + stride_dobs, + stride_doh, + stride_delta_tok, + stride_delta_head, + total_tokens, + Lk: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, +): + """Compute D_i = rowsum(O_i * dO_i) per query position per head.""" + head = tl.program_id(0) + offs_tok = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_DMODEL) + mask_tok = offs_tok < total_tokens + mask_d = offs_d < Lk + o = tl.load( + Out + offs_tok[:, None] * stride_obs + head * stride_oh + offs_d[None, :], + mask=mask_tok[:, None] & mask_d[None, :], + other=0.0, + ) + do = tl.load( + dO + offs_tok[:, None] * stride_dobs + head * stride_doh + offs_d[None, :], + mask=mask_tok[:, None] & mask_d[None, :], + other=0.0, + ) + delta = tl.sum(o * do, axis=1) + tl.store(Delta + offs_tok * stride_delta_tok + head * stride_delta_head, delta, mask=mask_tok) + + +@triton.jit +def _bwd_kernel_dq( + Q, + K, + V, + dO, + dQ, + Lse, + Delta, + B_Start_Loc, + B_Seqlen, + B_Start_Loc_K, + B_Seqlen_K, + qk_scale, + sm_scale, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_dobs, + stride_doh, + stride_dqbs, + stride_dqh, + stride_lse_tok, + stride_lse_head, + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + Lk: tl.constexpr, + APPLY_SPARSE24: tl.constexpr, + SKIP_DIAGONAL_BLOCKS: tl.constexpr, + BLOCK_FWD: tl.constexpr, +): + """Backward: compute dQ for one Q tile, looping over KV tiles.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_kv_len = tl.load(B_Seqlen_K + cur_batch) + cur_batch_q_start = tl.load(B_Start_Loc + cur_batch) + cur_batch_kv_start = tl.load(B_Start_Loc_K + cur_batch) + + if start_m * BLOCK_M >= cur_batch_seq_len: + return + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_DMODEL) + mask_d = offs_d < Lk + mask_m = offs_m < cur_batch_seq_len + + # Load Q, dO — stay in registers + off_qm = ( + (cur_batch_q_start + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] + ) + q = tl.load(Q + off_qm, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + off_dom = ( + (cur_batch_q_start + offs_m[:, None]) * stride_dobs + + cur_head * stride_doh + + offs_d[None, :] + ) + do = tl.load(dO + off_dom, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + + off_lse = (cur_batch_q_start + offs_m) * stride_lse_tok + cur_head * stride_lse_head + lse = tl.load(Lse + off_lse, mask=mask_m, other=0.0) + delta = tl.load(Delta + off_lse, mask=mask_m, other=0.0) + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + end_n = ( + cur_batch_kv_len if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_kv_len) + ) + offs_n = tl.arange(0, BLOCK_N) + + for start_n in range(0, end_n, BLOCK_N): + mask_n = (start_n + offs_n) < cur_batch_kv_len + # Load K^T [BLOCK_DMODEL, BLOCK_N] and V [BLOCK_N, BLOCK_DMODEL] + off_kn = ( + (cur_batch_kv_start + start_n + offs_n[None, :]) * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] + ) + kT = tl.load(K + off_kn, mask=mask_n[None, :] & mask_d[:, None], other=0.0) + off_vn = ( + (cur_batch_kv_start + start_n + offs_n[:, None]) * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] + ) + v = tl.load(V + off_vn, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + + # Recompute S [BLOCK_M, BLOCK_N] + s = tl.dot(q, kT) + s *= qk_scale + s = _mask_and_sparsify( + s, + offs_m, + offs_n, + cur_batch_seq_len, + cur_batch_kv_len, + start_n, + start_m, + IS_CAUSAL, + APPLY_SPARSE24, + SKIP_DIAGONAL_BLOCKS, + BLOCK_M, + BLOCK_N, + BLOCK_FWD, + ) + p = tl.math.exp2(s - lse[:, None]) + + # dP = dO @ V^T [BLOCK_M, BLOCK_N] + dp = tl.dot(do, tl.trans(v)) + # dS = P * (dP - delta) + ds = p * (dp - delta[:, None]) + # dQ += dS @ K (= dS @ kT^T) + dq += tl.dot(ds.to(kT.dtype), tl.trans(kT)) + + dq *= sm_scale + off_dqm = ( + (cur_batch_q_start + offs_m[:, None]) * stride_dqbs + + cur_head * stride_dqh + + offs_d[None, :] + ) + tl.store(dQ + off_dqm, dq.to(q.dtype), mask=mask_m[:, None] & mask_d[None, :]) + + +@triton.jit +def _bwd_kernel_dkdv( + Q, + K, + V, + dO, + dK, + dV, + Lse, + Delta, + B_Start_Loc, + B_Seqlen, + B_Start_Loc_K, + B_Seqlen_K, + qk_scale, + sm_scale, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_dobs, + stride_doh, + stride_dkbs, + stride_dkh, + stride_dvbs, + stride_dvh, + stride_lse_tok, + stride_lse_head, + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + Lk: tl.constexpr, + APPLY_SPARSE24: tl.constexpr, + SKIP_DIAGONAL_BLOCKS: tl.constexpr, + BLOCK_FWD: tl.constexpr, +): + """Backward: compute dK, dV for one KV tile, looping over Q tiles and GQA heads.""" + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_kv_len = tl.load(B_Seqlen_K + cur_batch) + cur_batch_q_start = tl.load(B_Start_Loc + cur_batch) + cur_batch_kv_start = tl.load(B_Start_Loc_K + cur_batch) + + kv_block_start = start_n * BLOCK_N + if kv_block_start >= cur_batch_kv_len: + return + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + mask_d = offs_d < Lk + mask_n = (kv_block_start + offs_n) < cur_batch_kv_len + + # Load K, V tiles [BLOCK_N, BLOCK_DMODEL] — stay in SRAM + abs_offs_n = kv_block_start + offs_n + off_kn = ( + (cur_batch_kv_start + abs_offs_n[:, None]) * stride_kbs + + cur_kv_head * stride_kh + + offs_d[None, :] + ) + off_vn = ( + (cur_batch_kv_start + abs_offs_n[:, None]) * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] + ) + k_tile = tl.load(K + off_kn, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + v_tile = tl.load(V + off_vn, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + kT_tile = tl.trans(k_tile) # [BLOCK_DMODEL, BLOCK_N] + + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + + num_q_tiles = (cur_batch_seq_len + BLOCK_M - 1) // BLOCK_M + q_tile_start = kv_block_start // BLOCK_M if IS_CAUSAL else 0 + offs_m_base = tl.arange(0, BLOCK_M) + + for q_tile_idx in range(q_tile_start, num_q_tiles): + offs_m = q_tile_idx * BLOCK_M + offs_m_base + mask_m = offs_m < cur_batch_seq_len + + for gqa_idx in range(kv_group_num): + cur_head = cur_kv_head * kv_group_num + gqa_idx + # Load Q, dO [BLOCK_M, BLOCK_DMODEL] + off_qm = ( + (cur_batch_q_start + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + q_tile = tl.load(Q + off_qm, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + off_dom = ( + (cur_batch_q_start + offs_m[:, None]) * stride_dobs + + cur_head * stride_doh + + offs_d[None, :] + ) + do_tile = tl.load(dO + off_dom, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + # Load lse, delta [BLOCK_M] + off_lse = (cur_batch_q_start + offs_m) * stride_lse_tok + cur_head * stride_lse_head + lse = tl.load(Lse + off_lse, mask=mask_m, other=0.0) + delta_val = tl.load(Delta + off_lse, mask=mask_m, other=0.0) + + # Recompute S [BLOCK_M, BLOCK_N] in ORIGINAL orientation + s = tl.dot(q_tile, kT_tile) + s *= qk_scale + s = _mask_and_sparsify( + s, + offs_m, + offs_n, + cur_batch_seq_len, + cur_batch_kv_len, + start_n * BLOCK_N, + q_tile_idx, + IS_CAUSAL, + APPLY_SPARSE24, + SKIP_DIAGONAL_BLOCKS, + BLOCK_M, + BLOCK_N, + BLOCK_FWD, + ) + p = tl.math.exp2(s - lse[:, None]) + + # dV += P^T @ dO + dv += tl.dot(tl.trans(p.to(do_tile.dtype)), do_tile) + # dP = dO @ V^T + dp = tl.dot(do_tile, tl.trans(v_tile)) + # dS = P * (dP - delta) + ds = p * (dp - delta_val[:, None]) + # dK += dS^T @ Q + dk += tl.dot(tl.trans(ds.to(q_tile.dtype)), q_tile) + + dk *= sm_scale + tl.store(dK + off_kn, dk.to(k_tile.dtype), mask=mask_n[:, None] & mask_d[None, :]) + tl.store(dV + off_vn, dv.to(v_tile.dtype), mask=mask_n[:, None] & mask_d[None, :]) + + +# --------------------------------------------------------------------------- +# Python helpers +# --------------------------------------------------------------------------- +def _prepare_fwd_args( + q, + k, + v, + b_start_loc, + b_seq_len, + max_input_len, + softmax_scale, + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + apply_sparse24, +): + """Validate inputs and derive common parameters.""" + if q.dim() != 3 or k.dim() != 3 or v.dim() != 3: + raise ValueError( + "q, k, v must be rank-3 [total_tokens, num_heads, head_dim]; " + f"got q.dim()={q.dim()}, k.dim()={k.dim()}, v.dim()={v.dim()}." + ) + head_dim = q.shape[2] + num_kv_heads = k.shape[1] + if num_kv_heads <= 0: + raise ValueError(f"k.shape[1] (num_kv_heads) must be positive; got {num_kv_heads}.") + if q.shape[1] % num_kv_heads != 0: + raise ValueError( + f"num_heads must be divisible by num_kv_heads; got {q.shape[1]} and {num_kv_heads}." + ) + if b_seq_len_k is None: + total_q = q.shape[0] + if k.shape[0] != total_q or v.shape[0] != total_q: + raise ValueError( + "For self-attention, q, k, v must have same shape[0]; " + f"got {q.shape[0]}, {k.shape[0]}, {v.shape[0]}." + ) + b_seq_len_k = b_seq_len + b_start_loc_k = b_start_loc + max_input_len_k = max_input_len + + batch = b_seq_len.shape[0] + if b_start_loc_k is None: + b_start_loc_k = torch.zeros(batch + 1, device=q.device, dtype=torch.int32) + b_start_loc_k[1:] = torch.cumsum(b_seq_len_k.to(torch.int64), dim=0) + b_start_loc_k = b_start_loc_k[:batch] + if max_input_len_k is None: + max_input_len_k = int(b_seq_len_k.max().item()) + + Lk = head_dim + num_q_heads = q.shape[1] + kv_group_num = num_q_heads // num_kv_heads + sm_scale = 1.0 / (Lk**0.5) if softmax_scale is None else softmax_scale + qk_scale = sm_scale * LOG2E + + capability = torch.cuda.get_device_capability() + BLOCK_FWD = 128 if capability[0] >= 8 else 64 + BLOCK_BWD = 64 # backward holds more tiles in SRAM; smaller block avoids shared memory overflow + if apply_sparse24 and BLOCK_BWD % 4 != 0: + raise ValueError(f"sparse24 requires BLOCK divisible by 4, got {BLOCK_BWD}") + num_warps_fwd = 4 if Lk <= 64 else 8 + num_warps_bwd = 4 # fewer warps to reduce shared memory pressure in backward + + return ( + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + sm_scale, + qk_scale, + BLOCK_FWD, + BLOCK_BWD, + num_warps_fwd, + num_warps_bwd, + Lk, + kv_group_num, + num_q_heads, + num_kv_heads, + batch, + ) + + +# --------------------------------------------------------------------------- +# Autograd wrapper +# --------------------------------------------------------------------------- +class _ContextAttentionFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + b_start_loc, + b_seq_len, + max_input_len, + is_causal, + softmax_scale, + apply_sparse24, + skip_diagonal_blocks, + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + ): + ( + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + sm_scale, + qk_scale, + BLOCK_FWD, + BLOCK_BWD, + num_warps_fwd, + num_warps_bwd, + Lk, + kv_group_num, + num_q_heads, + num_kv_heads, + batch, + ) = _prepare_fwd_args( + q, + k, + v, + b_start_loc, + b_seq_len, + max_input_len, + softmax_scale, + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + apply_sparse24, + ) + + o = torch.empty_like(q) + lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32) + BLOCK_DMODEL = triton.next_power_of_2(Lk) + grid = (batch, num_q_heads, triton.cdiv(max_input_len, BLOCK_FWD)) + + _fwd_kernel_prefill[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + lse, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_FWD, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_FWD, + IS_CAUSAL=is_causal, + Lk=Lk, + APPLY_SPARSE24=apply_sparse24, + SKIP_DIAGONAL_BLOCKS=skip_diagonal_blocks, + STORE_LSE=True, + num_warps=num_warps_fwd, + num_stages=1, + ) + + ctx.save_for_backward(q, k, v, o, lse, b_start_loc, b_seq_len, b_start_loc_k, b_seq_len_k) + ctx.max_input_len = max_input_len + ctx.max_input_len_k = max_input_len_k + ctx.sm_scale = sm_scale + ctx.qk_scale = qk_scale + ctx.is_causal = is_causal + ctx.apply_sparse24 = apply_sparse24 + ctx.skip_diagonal_blocks = skip_diagonal_blocks + ctx.BLOCK_FWD = BLOCK_FWD + ctx.BLOCK_BWD = BLOCK_BWD + ctx.num_warps_bwd = num_warps_bwd + ctx.Lk = Lk + ctx.kv_group_num = kv_group_num + ctx.num_q_heads = num_q_heads + ctx.num_kv_heads = num_kv_heads + ctx.batch = batch + return o + + @staticmethod + def backward(ctx, grad_output): + q, k, v, o, lse, b_start_loc, b_seq_len, b_start_loc_k, b_seq_len_k = ctx.saved_tensors + BLOCK = ctx.BLOCK_BWD + Lk = ctx.Lk + BLOCK_DMODEL = triton.next_power_of_2(Lk) + do = grad_output.contiguous() + num_warps = ctx.num_warps_bwd + + # Phase 1: delta = rowsum(O * dO) + delta = torch.empty_like(lse) + pre_grid = (ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK)) + _bwd_preprocess[pre_grid]( + o, + do, + delta, + o.stride(0), + o.stride(1), + do.stride(0), + do.stride(1), + delta.stride(0), + delta.stride(1), + q.shape[0], + Lk=Lk, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_M=BLOCK, + ) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + # Phase 2: dK, dV + grid_dkdv = (ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK)) + _bwd_kernel_dkdv[grid_dkdv]( + q, + k, + v, + do, + dk, + dv, + lse, + delta, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + ctx.qk_scale, + ctx.sm_scale, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + do.stride(0), + do.stride(1), + dk.stride(0), + dk.stride(1), + dv.stride(0), + dv.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + Lk=Lk, + APPLY_SPARSE24=ctx.apply_sparse24, + SKIP_DIAGONAL_BLOCKS=ctx.skip_diagonal_blocks, + BLOCK_FWD=ctx.BLOCK_FWD, + num_warps=num_warps, + num_stages=1, + ) + + # Phase 3: dQ + grid_dq = (ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK)) + _bwd_kernel_dq[grid_dq]( + q, + k, + v, + do, + dq, + lse, + delta, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + ctx.qk_scale, + ctx.sm_scale, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + do.stride(0), + do.stride(1), + dq.stride(0), + dq.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + Lk=Lk, + APPLY_SPARSE24=ctx.apply_sparse24, + SKIP_DIAGONAL_BLOCKS=ctx.skip_diagonal_blocks, + BLOCK_FWD=ctx.BLOCK_FWD, + num_warps=num_warps, + num_stages=1, + ) + + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- +def context_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + max_input_len: int, + is_causal: bool = True, + softmax_scale: float | None = None, + apply_sparse24: bool = False, + skip_diagonal_blocks: bool = True, + b_start_loc_k: torch.Tensor | None = None, + b_seq_len_k: torch.Tensor | None = None, + max_input_len_k: int | None = None, +) -> None: + """Inference-only attention (no backward). Writes output to ``o`` in-place.""" + if o.shape[0] != q.shape[0] or o.shape[1] != q.shape[1] or o.shape[2] != q.shape[2]: + raise ValueError(f"o must match q shape; got o={o.shape}, q={q.shape}.") + + ( + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + sm_scale, + qk_scale, + BLOCK_FWD, + _BLOCK_BWD, + num_warps_fwd, + _num_warps_bwd, + Lk, + kv_group_num, + num_q_heads, + num_kv_heads, + batch, + ) = _prepare_fwd_args( + q, + k, + v, + b_start_loc, + b_seq_len, + max_input_len, + softmax_scale, + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + apply_sparse24, + ) + + grid = (batch, num_q_heads, triton.cdiv(max_input_len, BLOCK_FWD)) + _fwd_kernel_prefill[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + None, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + 0, + 0, + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_FWD, + BLOCK_DMODEL=triton.next_power_of_2(Lk), + BLOCK_N=BLOCK_FWD, + IS_CAUSAL=is_causal, + Lk=Lk, + APPLY_SPARSE24=apply_sparse24, + SKIP_DIAGONAL_BLOCKS=skip_diagonal_blocks, + STORE_LSE=False, + num_warps=num_warps_fwd, + num_stages=1, + ) + + +def context_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + max_input_len: int, + is_causal: bool = True, + softmax_scale: float | None = None, + apply_sparse24: bool = False, + skip_diagonal_blocks: bool = True, + b_start_loc_k: torch.Tensor | None = None, + b_seq_len_k: torch.Tensor | None = None, + max_input_len_k: int | None = None, +) -> torch.Tensor: + """Attention with autograd support (training). Returns output tensor with grad_fn.""" + return _ContextAttentionFunc.apply( + q, + k, + v, + b_start_loc, + b_seq_len, + max_input_len, + is_causal, + softmax_scale, + apply_sparse24, + skip_diagonal_blocks, + b_start_loc_k, + b_seq_len_k, + max_input_len_k, + ) + + +__all__ = ["context_attention", "context_attention_fwd"] diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 8a109fda7..21c6f4312 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_skip_softmax +from . import flash_skip_softmax, sparse24_triton diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index e575de4da..3ebfd8c7e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -24,6 +24,9 @@ import numpy as np import torch +import torch.nn.functional as F + +from modelopt.torch.quantization.utils import replace_function from . import SparseAttentionMethod, register_sparse_method @@ -353,6 +356,19 @@ def get_threshold_info(self) -> dict[str, Any]: "value": self.threshold_config, } + def get_sparse_context(self, module: torch.nn.Module): + """Return a context manager that patches F.softmax with sparse masking.""" + original_softmax = F.softmax + + def sparse_softmax(input, dim=-1, *args, **kwargs): + sparse_mask, stats = self.calculate_sparsity(input) + module._last_stats = stats + if not self._calibration_mode: + input = self.apply_sparsity(input, sparse_mask) + return original_softmax(input, dim, *args, **kwargs) + + return replace_function(torch.nn.functional, "softmax", sparse_softmax) + @property def name(self) -> str: """Method identifier.""" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 6329e4446..3f3e78db6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -70,6 +70,18 @@ def apply_sparsity( Masked attention scores with sparse elements set to -inf """ + def get_sparse_context(self, module: torch.nn.Module): + """Return a context manager that activates this method's sparsity during forward. + + Each method subclass implements its own activation mechanism: + - Softmax-patching methods replace F.softmax during the forward pass. + - Kernel-fused methods set flags on ``module`` that the kernel reads. + + Args: + module: The SparseAttentionModule wrapping the attention layer. + """ + raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()") + def get_threshold_info(self) -> dict[str, Any]: """Get threshold information for display/debugging. diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py b/modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py new file mode 100644 index 000000000..ed32d40bd --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""2:4 structured sparse attention method for the Triton prefill kernel. + +This method is used with backend="triton" and attn_implementation="modelopt_triton". +Sparsity is applied inside the Triton kernel during prefill; this class provides +the SparseAttentionMethod interface for config-driven setup and optional diagnostics. +""" + +import contextlib +from typing import Any + +import torch + +from . import SparseAttentionMethod, register_sparse_method + + +def _sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor: + """Compute 2:4 mask: for every 4 elements along the last dim, keep the 2 largest. + + Args: + scores: Tensor of shape [..., N] with N divisible by 4. + + Returns: + Boolean mask of same shape; True where the element is kept (top-2 of 4). + """ + *prefix, n = scores.shape + assert n % 4 == 0, "2:4 sparsity requires last dim divisible by 4" + grouped = scores.reshape(*prefix, n // 4, 4) + # topk(2) along dim=-1; indices [..., 0] and [..., 1] are the two largest + _, top2_idx = torch.topk(grouped, k=2, dim=-1, largest=True, sorted=False) + mask = torch.zeros_like(grouped, dtype=torch.bool) + mask.scatter_(-1, top2_idx, True) + return mask.reshape(*prefix, n) + + +@register_sparse_method("sparse24_triton") +class Sparse24Triton(SparseAttentionMethod): + """2:4 structured sparse attention for the Triton prefill kernel. + + When backend is "triton", sparsity is applied inside the kernel; this method + provides the config interface and optional PyTorch-side diagnostics (e.g. + calculate_sparsity for stats). No calibration; pattern is fixed (top-2 of every 4). + """ + + def __init__(self, method_config: dict | None = None): + """Initialize 2:4 Triton sparse attention method. + + Args: + method_config: Configuration dict. Uses skip_diagonal_blocks, is_causal; + ignores threshold, br, bc (not used by 2:4). + """ + super().__init__() + config = method_config or {} + self.skip_diagonal_blocks = config.get("skip_diagonal_blocks", True) + self.is_causal = config.get("is_causal", True) + self.backend = config.get("backend", "triton") + + def _infer_phase(self, attention_scores: torch.Tensor) -> str: + """Infer phase from attention scores shape.""" + return "decode" if attention_scores.shape[2] == 1 else "prefill" + + def calculate_sparsity( + self, + attention_scores: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """Calculate 2:4 sparsity mask and statistics (PyTorch reference). + + Used for diagnostics when collect_stats is enabled. The actual sparsity + during forward with backend="triton" is applied inside the Triton kernel. + + Args: + attention_scores: [batch, heads, seq_q, seq_k] + + Returns: + (sparse_mask, stats_dict) + """ + assert attention_scores.dim() == 4, ( + f"Expected 4D attention scores, got shape {attention_scores.shape}" + ) + batch, num_heads, seq_q, seq_k = attention_scores.shape + phase = self._infer_phase(attention_scores) + + # Pad seq_k to multiple of 4 for 2:4 grouping + pad = (4 - seq_k % 4) % 4 + if pad > 0: + scores_padded = torch.nn.functional.pad( + attention_scores, (0, pad), value=torch.finfo(attention_scores.dtype).min + ) + else: + scores_padded = attention_scores + + mask_padded = _sparse24_mask_along_last_dim(scores_padded) + if pad > 0: + sparse_mask = mask_padded[..., :seq_k].contiguous() + else: + sparse_mask = mask_padded + + # 2:4 keeps 2 of 4 -> 50% kept (0.5 sparsity ratio as "fraction sparse" = 0.5) + sparsity = 0.5 + stats = { + "sparsity": sparsity, + "phase": phase, + "total_blocks": (seq_k + pad) // 4 * seq_q * num_heads * batch, + "sparse_blocks": int(0.5 * (seq_k + pad) // 4 * seq_q * num_heads * batch), + "sample_length": seq_k, + } + return sparse_mask, stats + + def apply_sparsity( + self, + attention_scores: torch.Tensor, + sparse_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply 2:4 sparsity mask to attention scores. + + Args: + attention_scores: [batch, heads, seq_q, seq_k] + sparse_mask: Optional pre-computed mask. If None, computes via calculate_sparsity. + + Returns: + Masked scores (same shape); masked positions set to dtype min. + """ + if sparse_mask is None: + sparse_mask, _ = self.calculate_sparsity(attention_scores) + mask_value = torch.finfo(attention_scores.dtype).min + return attention_scores.masked_fill(~sparse_mask, mask_value) + + @contextlib.contextmanager + def get_sparse_context(self, module: torch.nn.Module): + """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel.""" + module._apply_sparse24 = True + # Diagonal skip only applies to causal self-attention; for cross-attention + # there is no diagonal relationship between Q and K positions. + module._skip_diagonal_blocks = self.skip_diagonal_blocks and self.is_causal + try: + yield + finally: + module._apply_sparse24 = False + + def get_threshold_info(self) -> dict[str, Any]: + """Return fixed 2:4 pattern info.""" + return {"type": "fixed", "value": 0.5} + + @property + def name(self) -> str: + """Method identifier.""" + return "sparse24_triton" diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 828d126e8..90c473005 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -15,6 +15,7 @@ """Dynamic sparse attention registration for HuggingFace models.""" +import logging import warnings import torch.nn as nn @@ -25,6 +26,8 @@ from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry from . import CUSTOM_MODEL_PLUGINS +logger = logging.getLogger(__name__) + class _GenericSparseAttention(SparseAttentionModule): """Generic sparse attention that works with any HF attention module. @@ -93,10 +96,12 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) attention_types.add(module_type) registered_count += 1 - print(f"Registered {type_name} for sparse attention optimization") + logger.info("Registered %s for sparse attention optimization", type_name) if registered_count > 0: - print(f"Dynamically registered {registered_count} attention module types for sparsity") + logger.info( + "Dynamically registered %d attention module types for sparsity", registered_count + ) return registered_count > 0 @@ -124,10 +129,12 @@ def _is_supported_model(model: nn.Module) -> bool: def validate_eager_attention(model: nn.Module) -> None: - """Validate and enforce eager attention for HuggingFace models. + """Validate attention implementation for HuggingFace models. - Sparse attention requires attn_implementation='eager' because it - patches torch.nn.functional.softmax, which is only called in eager mode. + For softmax-patching methods (e.g. flash_skip_softmax) the model must use + attn_implementation='eager'. For the Triton 2:4 kernel (sparse24_triton) + the model must use attn_implementation='modelopt_triton'. We only force + eager when the current implementation is neither eager nor modelopt_triton. Args: model: Model to validate @@ -136,10 +143,10 @@ def validate_eager_attention(model: nn.Module) -> None: return attn_impl = getattr(model.config, "_attn_implementation", None) - if attn_impl and attn_impl != "eager": + if attn_impl and attn_impl not in ("eager", "modelopt_triton"): warnings.warn( - f"Sparse attention requires attn_implementation='eager', but model uses '{attn_impl}'. " - "Forcing eager attention implementation." + f"Sparse attention expects attn_implementation='eager' or 'modelopt_triton', " + f"but model uses '{attn_impl}'. Forcing eager attention implementation." ) model.config._attn_implementation = "eager" diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 281e11e7d..17afeccde 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -17,11 +17,7 @@ from typing import Any -import torch -import torch.nn.functional as F - from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls -from modelopt.torch.quantization.utils import replace_function from .config import SparseAttentionAttributeConfig from .methods import get_sparse_method @@ -32,28 +28,23 @@ class SparseAttentionModule(DynamicModule): """Generic sparse attention module wrapper for applying sparsity to attention layers. This module wraps existing attention implementations to add sparse attention - capabilities by patching torch.nn.functional.softmax. + capabilities. The activation mechanism is delegated to the configured method + via ``method.get_sparse_context(module)``, so each method defines how it + integrates with the forward pass (e.g. softmax patching, kernel flags). Forward Flow: ------------- 1. Check if sparse attention is enabled (pass-through if disabled) - 2. Create softmax patch context with sparse_softmax function - 3. Apply sparse attention by patching F.softmax: - - Patches torch.nn.functional.softmax with sparse_softmax - - sparse_softmax applies method's sparsity logic before softmax - 4. Forward through original attention with sparsity applied - - Requirements: - ------------- - - Model must be loaded with attn_implementation="eager" for proper softmax interception - - Only PyTorch backend is supported (patches F.softmax) + 2. Obtain method-specific context via ``_sparse_method_instance.get_sparse_context(self)`` + 3. Run the original forward inside the context + 4. Collect statistics if stats manager is enabled Attributes: ----------- _enabled: bool Whether sparse attention is enabled _method: str - The sparse attention method to use (e.g., "flash_skip_softmax") + The sparse attention method to use (e.g., "flash_skip_softmax", "sparse24_triton") _method_config: dict Configuration dictionary for the sparse method (threshold, br, bc, etc.) _sparse_method_instance: SparseAttentionMethod @@ -190,32 +181,12 @@ def forward(self, *args, **kwargs): return result def _get_sparse_context(self): - """Get the softmax patch context for applying sparse attention.""" - return self._create_softmax_patch_context() - - def _create_softmax_patch_context(self): - """Create context manager for patching softmax function.""" - return replace_function(torch.nn.functional, "softmax", self._create_sparse_softmax()) - - def _create_sparse_softmax(self): - """Create sparse softmax function for current method.""" - original_softmax = F.softmax + """Get the context manager for applying sparse attention. - def sparse_softmax(input, dim=-1, *args, **kwargs): - # Calculate sparsity mask and collect statistics - sparse_mask, stats = self._sparse_method_instance.calculate_sparsity(input) - - # Store stats for collection - self._last_stats = stats - - # Only apply sparsity mask after calibration (not during calibration) - # During calibration, we measure sparsity without modifying the output - if not self._sparse_method_instance._calibration_mode: - input = self._sparse_method_instance.apply_sparsity(input, sparse_mask) - - return original_softmax(input, dim, *args, **kwargs) - - return sparse_softmax + Delegates to the method instance so each method defines its own + activation mechanism (softmax patching, kernel flags, etc.). + """ + return self._sparse_method_instance.get_sparse_context(self) # Create registry for sparse attention modules diff --git a/pyproject.toml b/pyproject.toml index bffa547b6..3324dcecb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ extend-ignore = [ "E501", ] # Ignore missing docstrings or line length for Jupyter notebooks "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style +"modelopt/torch/sparsity/attention_sparsity/kernels/*" = ["N803", "N806"] # triton kernel style "examples/deepseek/ds_kernel.py" = ["N803", "N806", "E731"] # triton style [tool.ruff.lint.pycodestyle] diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py new file mode 100644 index 000000000..c906337ec --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py @@ -0,0 +1,862 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for Triton unified attention kernel.""" + +import pytest +import torch +import torch.nn.functional as F + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.sparsity.attention_sparsity.kernels import ( + IS_AVAILABLE as TRITON_KERNEL_AVAILABLE, +) + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention_fwd + + +def _sdpa_reference(q, k, v, b_start_loc, b_seq_len): + """SDPA causal reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" + batch = b_seq_len.shape[0] + num_q, num_kv = q.shape[1], k.shape[1] + parts = [] + for b in range(batch): + s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item()) + qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + if num_q != num_kv: + r = num_q // num_kv + kb = kb.repeat_interleave(r, dim=1) + vb = vb.repeat_interleave(r, dim=1) + ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=True) + parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) + return torch.cat(parts, dim=0) + + +def _get_prefill_block_size(): + """Return the BLOCK size the prefill kernel uses on the current GPU.""" + cap = torch.cuda.get_device_capability() + return 128 if cap[0] >= 8 else 64 + + +def _sparse24_top2(x0, x1, x2, x3): + """Top-2-of-4 mask (same logic as Triton _sparse24_noabs_ops).""" + a1, a2, a3 = x0 > x1, x0 > x2, x0 > x3 + a4, a5, a6 = x1 > x2, x1 > x3, x2 > x3 + m0 = (a2 and a3) or (a1 and a2) or (a1 and a3) + m1 = (not a1 and a5) or (a4 and a5) or (not a1 and a4) + m2 = (not a2 and not a4) or (not a2 and a6) or (not a4 and a6) + m3 = (not a3 and not a5) or (not a3 and not a6) or (not a5 and not a6) + return m0, m1, m2, m3 + + +def _attention_sparse24_ref(q, k, v, scale, bq, ts, skip_diag=True): + """Reference attention with 2:4 sparsity + diagonal skip. [seq, dim] -> [seq, dim].""" + n = q.shape[0] + scores = scale * (q @ k.T) + scores.masked_fill_( + torch.triu(torch.ones(n, n, device=scores.device, dtype=torch.bool), 1), float("-inf") + ) + nqb = (n + bq - 1) // bq + ntiles = (n + ts - 1) // ts + for qb in range(nqb): + qs, qe = qb * bq, min((qb + 1) * bq, n) + for t in range(ntiles): + ks, ke = t * ts, min((t + 1) * ts, n) + if skip_diag and ks < qe and ke > qs: + continue + for row in range(qs, qe): + for g in range((ke - ks) // 4): + c = ks + g * 4 + vals = [scores[row, c + i].item() for i in range(4)] + mask = _sparse24_top2(*vals) + for i in range(4): + if not mask[i]: + scores[row, c + i] = float("-inf") + return F.softmax(scores.float(), dim=-1).to(q.dtype) @ v + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" + from _test_utils.torch.transformers_models import create_tiny_llama_dir + + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + max_position_embeddings=64, + ) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestUnifiedAttentionVsSdpa: + """Triton unified attention matches PyTorch SDPA for prefill and decode.""" + + @pytest.mark.parametrize( + ("dtype", "num_heads", "num_kv_heads", "head_dim", "tol"), + [ + (torch.float32, 2, 2, 32, 1e-2), + (torch.float16, 4, 2, 64, 2e-2), + ], + ids=["fp32_mha", "fp16_gqa"], + ) + def test_prefill_matches_sdpa(self, dtype, num_heads, num_kv_heads, head_dim, tol): + """Prefill via context_attention_fwd matches SDPA (variable-length batch).""" + seq_lens = [8, 12] + total = sum(seq_lens) + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(123) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) + lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) + + o = torch.empty_like(q) + context_attention_fwd( + q, + k, + v, + o, + b_start_loc=locs, + b_seq_len=lens, + max_input_len=max(seq_lens), + is_causal=True, + softmax_scale=scale, + ) + torch.testing.assert_close(o, _sdpa_reference(q, k, v, locs, lens), rtol=tol, atol=tol) + + def test_cross_attention_matches_sdpa(self): + """Non-causal cross-attention: different Q and K/V lengths, matches SDPA.""" + seq_q, seq_k = 6, 10 + num_heads, num_kv_heads, head_dim = 4, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(501) + q = torch.randn(seq_q, num_heads, head_dim, device="cuda", dtype=torch.float32) + k = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + v = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + + o = torch.empty_like(q) + context_attention_fwd( + q, + k, + v, + o, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_q], device="cuda", dtype=torch.int32), + max_input_len=seq_q, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len_k=torch.tensor([seq_k], device="cuda", dtype=torch.int32), + max_input_len_k=seq_k, + ) + + # Reference: SDPA non-causal + q_ref = q.unsqueeze(0).permute(0, 2, 1, 3) # [1, heads, seq_q, dim] + k_ref = k.unsqueeze(0).permute(0, 2, 1, 3) + v_ref = v.unsqueeze(0).permute(0, 2, 1, 3) + k_ref = k_ref.repeat_interleave(num_heads // num_kv_heads, dim=1) + v_ref = v_ref.repeat_interleave(num_heads // num_kv_heads, dim=1) + o_ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=False) + o_ref = o_ref.permute(0, 2, 1, 3).squeeze(0) + + torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) + + def test_decode_matches_sdpa(self): + """Decode via context_attention_fwd(is_causal=False) matches per-sample SDPA.""" + batch = 2 + seq_lens_k = [5, 9] # KV lengths (context + current token) + num_heads, num_kv_heads, head_dim = 4, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(103) + # Q: one token per batch element -> flat [batch, num_heads, head_dim] + q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float32) + + # K/V: variable-length, packed into flat tensors + total_kv = sum(seq_lens_k) + k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + + cumsum = [0] + for sl in seq_lens_k: + cumsum.append(cumsum[-1] + sl) + b_start_loc_q = torch.arange(batch, device="cuda", dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device="cuda", dtype=torch.int32) + b_start_loc_k = torch.tensor(cumsum[:-1], device="cuda", dtype=torch.int32) + b_seq_len_k = torch.tensor(seq_lens_k, device="cuda", dtype=torch.int32) + + out = torch.empty_like(q_flat) + context_attention_fwd( + q_flat, + k_flat, + v_flat, + out, + b_start_loc=b_start_loc_q, + b_seq_len=b_seq_len_q, + max_input_len=1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + ) + + for i in range(batch): + sl = seq_lens_k[i] + s = cumsum[i] + qb = q_flat[i : i + 1].unsqueeze(2) # [1, heads, 1, dim] + kb = k_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) + vb = v_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) + kb = kb.repeat_interleave(num_heads // num_kv_heads, dim=1) + vb = vb.repeat_interleave(num_heads // num_kv_heads, dim=1) + ref = F.scaled_dot_product_attention(qb, kb, vb, is_causal=False).squeeze(2) + torch.testing.assert_close(out[i : i + 1], ref, rtol=1e-2, atol=1e-2) + + def test_prefill_decode_consistency(self): + """Last token of prefill matches decode output for the same sequence.""" + seq_len = 8 + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(104) + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32) + k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + + # Prefill (causal) + o_pf = torch.empty_like(q) + context_attention_fwd( + q, + k, + v, + o_pf, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + ) + + # Decode: last token as query, full K/V (non-causal to attend to all) + q_dec = q[-1:].contiguous() + o_dec = torch.empty_like(q_dec) + context_attention_fwd( + q_dec, + k, + v, + o_dec, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([1], device="cuda", dtype=torch.int32), + max_input_len=1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len_k=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len_k=seq_len, + ) + + torch.testing.assert_close(o_pf[-1:], o_dec, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparse24Attention: + """2:4 sparse attention applied inside the Triton kernel.""" + + def test_sparse24_output_differs_from_dense(self): + """Sparse24 enabled produces different (but valid) output vs dense.""" + block = _get_prefill_block_size() + seq_lens = [block * 2, block * 3] + total = sum(seq_lens) + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(789) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.float32) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) + lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) + + kw = { + "b_start_loc": locs, + "b_seq_len": lens, + "max_input_len": max(seq_lens), + "is_causal": True, + "softmax_scale": scale, + } + + o_dense = torch.empty_like(q) + context_attention_fwd(q, k, v, o_dense, apply_sparse24=False, **kw) + o_sparse = torch.empty_like(q) + context_attention_fwd( + q, k, v, o_sparse, apply_sparse24=True, skip_diagonal_blocks=True, **kw + ) + + assert not torch.equal(o_dense, o_sparse), "Sparse should differ from dense" + assert not torch.isnan(o_sparse).any() and not torch.isinf(o_sparse).any() + + def test_sparse24_matches_reference(self): + """Sparse24 with GQA (4 q-heads, 2 kv-heads) matches Python reference.""" + block = _get_prefill_block_size() + seq_len = block * 2 + block // 2 # ensure non-trivial diagonal + off-diagonal tiles + num_heads, num_kv_heads, head_dim = 4, 2, 32 + nqkv = num_heads // num_kv_heads + scale = 1.0 / (head_dim**0.5) + bq, ts = block, block + + torch.manual_seed(303) + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32) + k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + + o_tri = torch.empty_like(q) + context_attention_fwd( + q, + k, + v, + o_tri, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + apply_sparse24=True, + skip_diagonal_blocks=True, + ) + + o_ref = torch.empty_like(q) + for h in range(num_heads): + o_ref[:, h] = _attention_sparse24_ref( + q[:, h], + k[:, h // nqkv], + v[:, h // nqkv], + scale, + bq, + ts, + ) + + torch.testing.assert_close(o_tri, o_ref, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseAttentionIntegration: + """HF model + mtsa.sparsify integration.""" + + def test_triton_forward_and_generate(self, tiny_llama_dir): + """modelopt_triton attention: prefill logits valid, generate produces tokens.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer + + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + model.eval() + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + + ids = tok("The capital of France is", return_tensors="pt").input_ids.to("cuda") + with torch.no_grad(): + logits = model(input_ids=ids).logits + assert not torch.isnan(logits).any() and not torch.isinf(logits).any() + + with torch.no_grad(): + out = model.generate( + ids, max_new_tokens=5, do_sample=False, pad_token_id=tok.pad_token_id + ) + assert out.shape[1] == ids.shape[1] + 5 + + def test_sparsify_sparse24_produces_valid_output(self, tiny_llama_dir): + """mtsa.sparsify(model, SPARSE24_TRITON) forward produces valid logits.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer + + import modelopt.torch.sparsity.attention_sparsity as mtsa + from modelopt.torch.sparsity.attention_sparsity.config import SPARSE24_TRITON + + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + model = mtsa.sparsify(model, SPARSE24_TRITON) + model.eval() + + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + ids = tok("Hello world", return_tensors="pt").input_ids.to("cuda") + + with torch.no_grad(): + logits = model(input_ids=ids).logits + assert not torch.isnan(logits).any() and not torch.isinf(logits).any() + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestBackward: + """Backward pass gradient correctness tests.""" + + def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): + """Run SDPA forward+backward, return output and gradients.""" + q_ref = q.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + k_ref = k.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + v_ref = v.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + num_q, num_kv = q_ref.shape[1], k_ref.shape[1] + if num_q != num_kv: + r = num_q // num_kv + k_exp = k_ref.repeat_interleave(r, dim=1) + v_exp = v_ref.repeat_interleave(r, dim=1) + else: + k_exp, v_exp = k_ref, v_ref + o_ref = F.scaled_dot_product_attention( + q_ref, k_exp, v_exp, is_causal=is_causal, scale=scale + ) + o_ref.sum().backward() + dq = q_ref.grad.permute(0, 2, 1, 3).squeeze(0) + dk = k_ref.grad.permute(0, 2, 1, 3).squeeze(0) + dv = v_ref.grad.permute(0, 2, 1, 3).squeeze(0) + return o_ref.permute(0, 2, 1, 3).squeeze(0).detach(), dq.detach(), dk.detach(), dv.detach() + + def test_backward_causal_matches_sdpa(self): + """dQ, dK, dV match SDPA backward for causal self-attention.""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + seq_len = 16 + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(42) + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + k = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + v = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + + o = context_attention( + q, + k, + v, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + ) + o.sum().backward() + + _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( + q.detach(), k.detach(), v.detach(), scale, is_causal=True + ) + + torch.testing.assert_close(q.grad, dq_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k.grad, dk_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(v.grad, dv_ref, rtol=1e-2, atol=1e-2) + + def test_backward_gqa(self): + """Backward with GQA (4 q-heads, 2 kv-heads) matches SDPA.""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + seq_len = 16 + num_heads, num_kv_heads, head_dim = 4, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(43) + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + k = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + v = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + + o = context_attention( + q, + k, + v, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + ) + o.sum().backward() + + _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( + q.detach(), k.detach(), v.detach(), scale, is_causal=True + ) + + torch.testing.assert_close(q.grad, dq_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k.grad, dk_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(v.grad, dv_ref, rtol=1e-2, atol=1e-2) + + def test_backward_sparse24_finite(self): + """Backward with sparse24 produces finite, non-zero gradients.""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + block = _get_prefill_block_size() + seq_len = block * 2 + block // 2 + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(44) + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + k = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + v = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + + o = context_attention( + q, + k, + v, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + apply_sparse24=True, + skip_diagonal_blocks=True, + ) + o.sum().backward() + + for name, grad in [("dQ", q.grad), ("dK", k.grad), ("dV", v.grad)]: + assert grad is not None, f"{name} gradient is None" + assert not torch.isnan(grad).any(), f"{name} has NaN" + assert not torch.isinf(grad).any(), f"{name} has Inf" + assert grad.abs().sum() > 0, f"{name} is all zeros" + + def test_backward_multi_batch_variable_length(self): + """Multi-batch variable-length causal backward matches per-sample SDPA.""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + seq_lens = [8, 12] + total = sum(seq_lens) + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(45) + q = torch.randn( + total, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + k = torch.randn( + total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + v = torch.randn( + total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) + lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) + + o = context_attention( + q, + k, + v, + b_start_loc=locs, + b_seq_len=lens, + max_input_len=max(seq_lens), + is_causal=True, + softmax_scale=scale, + ) + o.sum().backward() + + # Per-sample SDPA reference + dq_ref = torch.zeros_like(q) + dk_ref = torch.zeros_like(k) + dv_ref = torch.zeros_like(v) + for b in range(len(seq_lens)): + s, n = int(locs[b].item()), seq_lens[b] + _, dq_b, dk_b, dv_b = self._sdpa_backward_ref( + q.detach()[s : s + n], + k.detach()[s : s + n], + v.detach()[s : s + n], + scale, + is_causal=True, + ) + dq_ref[s : s + n] = dq_b + dk_ref[s : s + n] = dk_b + dv_ref[s : s + n] = dv_b + + torch.testing.assert_close(q.grad, dq_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k.grad, dk_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(v.grad, dv_ref, rtol=1e-2, atol=1e-2) + + def test_backward_cross_attention(self): + """Non-causal cross-attention backward with different Q and K/V lengths.""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + seq_q, seq_k = 6, 10 + num_heads, num_kv_heads, head_dim = 4, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(46) + q = torch.randn( + seq_q, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + k = torch.randn( + seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + v = torch.randn( + seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + + o = context_attention( + q, + k, + v, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_q], device="cuda", dtype=torch.int32), + max_input_len=seq_q, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len_k=torch.tensor([seq_k], device="cuda", dtype=torch.int32), + max_input_len_k=seq_k, + ) + o.sum().backward() + + # SDPA reference (non-causal, GQA-expanded) + q_ref = q.detach().clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + k_ref = k.detach().clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + v_ref = v.detach().clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + r = num_heads // num_kv_heads + k_exp = k_ref.repeat_interleave(r, dim=1) + v_exp = v_ref.repeat_interleave(r, dim=1) + o_ref = F.scaled_dot_product_attention(q_ref, k_exp, v_exp, is_causal=False, scale=scale) + o_ref.sum().backward() + + torch.testing.assert_close( + q.grad, q_ref.grad.permute(0, 2, 1, 3).squeeze(0), rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + k.grad, k_ref.grad.permute(0, 2, 1, 3).squeeze(0), rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + v.grad, v_ref.grad.permute(0, 2, 1, 3).squeeze(0), rtol=1e-2, atol=1e-2 + ) + + def test_backward_sparse24_matches_reference(self): + """Sparse24 backward dQ/dK/dV match a Python reference with manual 2:4 masking.""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + block = _get_prefill_block_size() + seq_len = block * 2 + block // 2 + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + bq, ts = block, block + + torch.manual_seed(47) + q_data = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32) + k_data = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + v_data = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + + # Triton backward + q = q_data.clone().requires_grad_(True) + k = k_data.clone().requires_grad_(True) + v = v_data.clone().requires_grad_(True) + o = context_attention( + q, + k, + v, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + apply_sparse24=True, + skip_diagonal_blocks=True, + ) + o.sum().backward() + + # Python reference backward (per head, using _attention_sparse24_ref logic) + dq_ref = torch.zeros_like(q_data) + dk_ref = torch.zeros_like(k_data) + dv_ref = torch.zeros_like(v_data) + for h in range(num_heads): + kv_h = h // (num_heads // num_kv_heads) + q_h = q_data[:, h].clone().requires_grad_(True) + k_h = k_data[:, kv_h].clone().requires_grad_(True) + v_h = v_data[:, kv_h].clone().requires_grad_(True) + o_h = _attention_sparse24_ref(q_h, k_h, v_h, scale, bq, ts) + o_h.sum().backward() + dq_ref[:, h] = q_h.grad + dk_ref[:, kv_h] += k_h.grad + dv_ref[:, kv_h] += v_h.grad + + torch.testing.assert_close(q.grad, dq_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k.grad, dk_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(v.grad, dv_ref, rtol=1e-2, atol=1e-2) + + def test_backward_matches_sdpa_all_grads(self): + """All three gradients match SDPA across multiple configs (smoke test).""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + configs = [ + (4, 2, 2, 16, True), # small causal + (8, 4, 2, 32, True), # GQA causal + (6, 2, 2, 32, False), # non-causal + ] + for seq_len, num_heads, num_kv_heads, head_dim, is_causal in configs: + scale = 1.0 / (head_dim**0.5) + torch.manual_seed(48) + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + k = torch.randn( + seq_len, + num_kv_heads, + head_dim, + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + v = torch.randn( + seq_len, + num_kv_heads, + head_dim, + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + + o = context_attention( + q, + k, + v, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=is_causal, + softmax_scale=scale, + ) + o.sum().backward() + + _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( + q.detach(), k.detach(), v.detach(), scale, is_causal=is_causal + ) + tag = f"seq={seq_len},heads={num_heads}/{num_kv_heads},causal={is_causal}" + torch.testing.assert_close(q.grad, dq_ref, rtol=1e-2, atol=1e-2, msg=f"dQ {tag}") + torch.testing.assert_close(k.grad, dk_ref, rtol=1e-2, atol=1e-2, msg=f"dK {tag}") + torch.testing.assert_close(v.grad, dv_ref, rtol=1e-2, atol=1e-2, msg=f"dV {tag}") + + def test_backward_longer_sequences(self): + """Backward with seq_len=256 exercises multi-tile loops (BLOCK=128).""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + seq_len = 256 + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(49) + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + k = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + v = torch.randn( + seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + ) + + o = context_attention( + q, + k, + v, + b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), + b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + ) + o.sum().backward() + + _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( + q.detach(), k.detach(), v.detach(), scale, is_causal=True + ) + + torch.testing.assert_close(q.grad, dq_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k.grad, dk_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(v.grad, dv_ref, rtol=1e-2, atol=1e-2) + + def test_forward_backward_matches_forward_only(self): + """context_attention forward output matches context_attention_fwd.""" + from modelopt.torch.sparsity.attention_sparsity.kernels import context_attention + + seq_len = 16 + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(50) + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32) + k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) + locs = torch.tensor([0], device="cuda", dtype=torch.int32) + lens = torch.tensor([seq_len], device="cuda", dtype=torch.int32) + + # Forward-only path + o_fwd = torch.empty_like(q) + context_attention_fwd( + q, + k, + v, + o_fwd, + b_start_loc=locs, + b_seq_len=lens, + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + ) + + # Forward+backward path (just use forward output) + o_bwd = context_attention( + q, + k, + v, + b_start_loc=locs, + b_seq_len=lens, + max_input_len=seq_len, + is_causal=True, + softmax_scale=scale, + ) + + torch.testing.assert_close(o_fwd, o_bwd, rtol=1e-5, atol=1e-5)