Added inter document masking for manual and flash attention. #434

Draft
BlueCrescent wants to merge 7 commits into main from inter_document_masking_for_attention

Conversation

@BlueCrescent (Member)

What does this PR do?

Adds inter-document masking for manual and flash attention.

General Changes

  • Added prepare_inter_document_masking() to CausalSelfAttention, which computes 3D attention masks for manual attention and cu_seqlens for DAO flash attention. Its input is the list of sub-sequence lengths for each sequence, so padded sequences are also supported (a standalone sketch of the idea follows this list).
  • When this information is provided to the attention's forward() call, inter-document masking is applied.
  • Added thorough tests for CausalSelfAttention.
  • Integrated into GPT2Model.
  • TODO: test GPT2Model, support pipeline parallelism (PP), and create a corresponding dataloader.
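For illustration, here is a minimal, self-contained sketch of the idea behind prepare_inter_document_masking(). The helper names (block_diagonal_mask, cumulative_seq_lens) and shapes are placeholders, not the PR's actual code: manual attention gets a boolean block-diagonal mask that only allows attention within a document, while the flash path gets cumulative sequence lengths (cu_seqlens) over all documents in the batch.

# Illustrative sketch only; helper names and shapes are assumptions, not the PR's code.
import torch


def block_diagonal_mask(in_batch_seq_lens: list[list[int]], max_seq_len: int) -> torch.Tensor:
    """Boolean (batch, max_seq_len, max_seq_len) mask that allows attention only
    within each document; padding positions remain fully masked."""
    mask = torch.zeros((len(in_batch_seq_lens), max_seq_len, max_seq_len), dtype=torch.bool)
    for i, doc_lens in enumerate(in_batch_seq_lens):
        offset = 0
        for length in doc_lens:
            mask[i, offset:offset + length, offset:offset + length] = True
            offset += length
    return mask  # combined with the usual causal mask inside manual attention


def cumulative_seq_lens(in_batch_seq_lens: list[list[int]]) -> torch.Tensor:
    """Flatten per-document lengths across the batch into the cu_seqlens format
    used by flash-attn varlen kernels: [0, l0, l0 + l1, ...]."""
    flat = [length for doc_lens in in_batch_seq_lens for length in doc_lens]
    return torch.tensor([0] + flat, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)


# Two sequences of length 8: the first holds documents of length 3 and 5,
# the second one document of length 6 followed by 2 padding tokens.
mask = block_diagonal_mask([[3, 5], [6]], max_seq_len=8)
cu_seqlens = cumulative_seq_lens([[3, 5], [6]])  # tensor([0, 3, 8, 14], dtype=torch.int32)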

Breaking Changes

  • None. If no inter-document sequence lengths are provided, the behavior remains unchanged.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

Copilot AI (Contributor) left a comment

Pull request overview

Adds inter-document masking support to the GPT-2 attention stack so that sequences containing multiple concatenated documents do not attend across document boundaries, for both manual attention and DAO flash attention (via the varlen path).

Changes:

  • Added CausalSelfAttention.prepare_inter_document_masking() and threaded optional masking info through attention execution paths.
  • Implemented the DAO flash varlen execution path to support document-wise masking/splitting without padding leakage (a generic usage sketch follows this list).
  • Added extensive unit tests for inter-document masking behaviors and edge cases.
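For reference, the varlen path of Dao's flash-attention is typically driven by packing only the real (non-padding) tokens into a flat (total_tokens, n_heads, head_dim) layout together with cu_seqlens. The sketch below shows such a call in general terms; it is not taken from the PR, and the tensor shapes and values are illustrative.

# Typical flash-attn varlen call; illustrative only, not the PR's wiring.
import torch
from flash_attn import flash_attn_varlen_func  # Dao-AILab flash-attention package

n_heads, head_dim = 4, 16
cu_seqlens = torch.tensor([0, 3, 8, 14], dtype=torch.int32, device="cuda")  # document boundaries
total_tokens = int(cu_seqlens[-1])
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())

# q, k, v contain only real tokens, flattened across batch and documents.
q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# causal=True keeps the causal mask; cu_seqlens confines attention to each document,
# so neither attention nor padding leaks across document boundaries.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True,
)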

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.

  • tests/models/test_causal_self_attention.py: adds comprehensive tests for inter-document masking across the manual and DAO flash attention implementations.
  • src/modalities/models/gpt2/gpt2_model.py: implements inter-document masking preparation and the DAO flash varlen execution path, and integrates masking into GPT2LLM, GPT2Block, and CausalSelfAttention.


Comment on lines +21 to 23
torch.manual_seed(0) # FIXME remove or do within tests?


Copilot AI commented on Feb 23, 2026

Setting torch.manual_seed(0) at module import time mutates global RNG state for the entire test run, which can make unrelated tests order-dependent. Prefer seeding inside the specific tests that need determinism (or via a fixture) instead of at import scope.

Suggested change
torch.manual_seed(0) # FIXME remove or do within tests?
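For instance, the seeding could live in a fixture so that only tests that request it touch the global RNG state; the fixture and test names below are hypothetical and not from this PR.

# Hypothetical fixture-based seeding; names are placeholders, not from this PR.
import pytest
import torch


@pytest.fixture
def fixed_seed() -> None:
    torch.manual_seed(0)  # scoped to tests that explicitly request this fixture


def test_masking_is_deterministic(fixed_seed) -> None:
    first = torch.randn(2, 8)
    torch.manual_seed(0)
    second = torch.randn(2, 8)
    assert torch.equal(first, second)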

Comment on lines +523 to +533
device = self.c_proj.weight.device
if self.attention_impl == AttentionImplementation.MANUAL:
    batch_size = len(in_batch_seq_lens)
    attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
    for i, doc_seq_lens in enumerate(in_batch_seq_lens):
        doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
        for j in range(len(doc_boundaries) - 1):
            start_idx = doc_boundaries[j]
            end_idx = doc_boundaries[j + 1]
            attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
    return attn_mask
Copilot AI commented on Feb 23, 2026

prepare_inter_document_masking() validates in_batch_seq_lens for the dao_flash path but not for the manual path. If a batch item's subsequence lengths sum to more than max_seq_len (or contain invalid values), the manual path can produce incorrect masks or raise a low-level indexing error. Consider applying the same validation (e.g., reuse _build_concatenated_lengths_tensor checks) for the manual implementation too.
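For illustration, a validation along those lines could look as follows. This is a sketch written against the variables visible in the quoted snippet, not the PR's actual code; the existing _build_concatenated_lengths_tensor checks may be reusable instead.

# Illustrative validation for the manual path; not the PR's actual code.
for doc_seq_lens in in_batch_seq_lens:
    if any(length <= 0 for length in doc_seq_lens):
        raise ValueError(f"Sub-sequence lengths must be positive, got {doc_seq_lens}.")
    if sum(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Sub-sequence lengths {doc_seq_lens} sum to {sum(doc_seq_lens)}, "
            f"exceeding max_seq_len={max_seq_len}."
        )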

Comment on lines +788 to 792
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
    Optional tensor containing masking information for inter-document attention.

Returns:
    torch.Tensor: The output tensor.
Copilot AI commented on Feb 23, 2026

execute_attention() now accepts attention_masking_information, but inter-document masking is only applied in the MANUAL / DAO_FLASH implementations. For PYTORCH_FLASH, the mask is currently silently ignored (it always passes attn_mask=None). To avoid surprising behavior, consider raising NotImplementedError (or asserting the mask is None) when attention_impl is PYTORCH_FLASH and a mask is provided.
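A sketch of such a guard, using names visible in the quoted code; where exactly it would sit inside execute_attention() is an assumption.

# Illustrative guard; exact placement inside execute_attention() is an assumption.
if (
    attention_masking_information is not None
    and self.attention_impl == AttentionImplementation.PYTORCH_FLASH
):
    raise NotImplementedError(
        "Inter-document masking is not supported for the PYTORCH_FLASH implementation."
    )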
