diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 5757e64a9..e5806f0ac 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -443,6 +443,21 @@ class BatchSamplerConfig(BaseModel): class GPT2LLMCollateFnConfig(BaseModel): sample_key: str target_key: str + sub_seq_lengths_key: str | None = None + eos_token_id: int | None = None + padding_token_id: int | None = None + + @model_validator(mode="after") + def check_sub_seq_lengths_and_eos_token(self) -> "GPT2LLMCollateFnConfig": + if (self.sub_seq_lengths_key is None) != (self.eos_token_id is None): + raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.") + return self + + @model_validator(mode="after") + def check_padding_token_and_sub_seq_lengths(self) -> "GPT2LLMCollateFnConfig": + if self.padding_token_id is not None and self.sub_seq_lengths_key is None: + raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.") + return self class LLMDataLoaderConfig(BaseModel): diff --git a/src/modalities/models/gpt2/collator.py b/src/modalities/models/gpt2/collator.py index f4cf9b531..fd767863e 100644 --- a/src/modalities/models/gpt2/collator.py +++ b/src/modalities/models/gpt2/collator.py @@ -7,16 +7,31 @@ class GPT2LLMCollateFn(CollateFnIF): """GPT2LLMCollateFn class to define a collate function for GPT2 language model.""" - def __init__(self, sample_key: str, target_key: str): + def __init__( + self, + sample_key: str, + target_key: str, + sub_seq_lengths_key: str | None = None, + eos_token_id: int | None = None, + padding_token_id: int | None = None, + ): """ Initializes the Collator object. + If the eos token ID and the sub_seq_lengths_key are provided, + a list[list[int]] representing the sub-sequence lengths will be created. Args: sample_key (str): The key for accessing the sample data. target_key (str): The key for accessing the target data. 
+ sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths. + eos_token_id (int | None): The end-of-sequence token ID. + padding_token_id (int | None): The padding token ID. """ self.sample_key = sample_key self.target_key = target_key + self.sub_seq_lengths_key = sub_seq_lengths_key + self.eos_token_id = eos_token_id + self.padding_token_id = padding_token_id def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: """ @@ -33,4 +48,43 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch]) samples = {self.sample_key: sample_tensor[:, :-1]} targets = {self.target_key: sample_tensor[:, 1:]} + if self.sub_seq_lengths_key is not None: + # Determine sub sequence lengths by finding the eos tokens in each sequence in the batch. + sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(samples[self.sample_key]) + samples[self.sub_seq_lengths_key] = sub_seq_lengths return DatasetBatch(targets=targets, samples=samples) + + def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]: + sub_seq_lengths = [] + for seq in sample_tensor: + eos_positions = (seq == self.eos_token_id).nonzero(as_tuple=True)[0] + if len(eos_positions) == 0: + assert ( + self.padding_token_id is None or seq[0] != self.padding_token_id + ), "Sequence starts with padding token" + sub_seq_lengths.append([len(seq)]) + else: + subseq_lengths = self._compute_subsequence_length(seq, eos_positions) + sub_seq_lengths.append(subseq_lengths) + return sub_seq_lengths + + def _compute_subsequence_length(self, seq: torch.Tensor, eos_positions: torch.Tensor) -> list[int]: + # If the last sequence is cut, i.e. does not end on an eos token, + # it should also be included unless the padding token is set and + # the last sequence is just padding. 
+ last_eos_pos = eos_positions[-1].item() + if self._has_cutoff_final_sequence(seq, last_eos_pos): + eos_positions = torch.cat([eos_positions, torch.tensor([len(seq) - 1])]) + # Compute length of each subsequence and add to lengths list. + subseq_lengths = [] + prev_pos = 0 + for pos in eos_positions: + subseq_lengths.append(pos.item() - prev_pos + 1) + prev_pos = pos.item() + 1 + return subseq_lengths + + def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eos_pos: int) -> bool: + # Assumption: If the first token of the last sequence is padding, so is the rest. + return last_eos_pos < len(seq) - 1 and ( + self.padding_token_id is None or seq[last_eos_pos + 1] != self.padding_token_id + ) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 70f595e67..d2634aa48 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -20,9 +20,10 @@ from modalities.util import parse_enum_by_name try: - from flash_attn import flash_attn_func + from flash_attn import flash_attn_func, flash_attn_varlen_func except ModuleNotFoundError: flash_attn_func = None + flash_attn_varlen_func = None # Logger configuration logger = logging.getLogger(__name__) @@ -372,6 +373,7 @@ class GPT2LLMConfig(BaseModel): use_weight_tying: bool seed: Optional[int] = None enforce_swiglu_hidden_dim_multiple_of: int = 256 + sub_seq_lengths_key: str | None = None @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -501,6 +503,184 @@ def __init__( self.q_norm = None self.k_norm = None + def prepare_inter_document_masking( + self, in_batch_seq_lens: list[list[int]], max_seq_len: int + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]: + """ + Prepares the inter-document attention mask based on the input batch sequence lengths. + For manual attention, a 3D attention mask of shape (batch_size, total_seq_len, total_seq_len) is returned. 
+ For flash attention, the cu_seqlens are computed and returned along with the indices + of valid tokens and the maximum sequence length in the batch. + For sdp attention, an exception is raised for now. + + Args: + in_batch_seq_lens (list[list[int]]): A list of lists containing the sequence + lengths for each document in the batch. + max_seq_len (int): The maximum sequence length in the batch. + + Returns: + torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]: The inter-document masking information. + """ + device = self.c_proj.weight.device + if self.attention_impl == AttentionImplementation.MANUAL: + batch_size = len(in_batch_seq_lens) + attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device) + for i, doc_seq_lens in enumerate(in_batch_seq_lens): + doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0) + for j in range(len(doc_boundaries) - 1): + start_idx = doc_boundaries[j] + end_idx = doc_boundaries[j + 1] + attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True + return attn_mask + if self.attention_impl == AttentionImplementation.DAO_FLASH: + concatenated_lengths = self._build_concatenated_lengths_tensor( + in_batch_seq_lens=in_batch_seq_lens, + max_seq_len=max_seq_len, + device=device, + ) + return self._get_unpad_data_for_concatenated_sequences(concatenated_lengths) + if self.attention_impl == AttentionImplementation.PYTORCH_FLASH: + raise NotImplementedError( + "Inter-document masking is not supported for `pytorch_flash`. " "Use `manual` or `dao_flash`." + ) + raise NotImplementedError( + f"Attention implementation {self.attention_impl} is not supported for inter-document masking." + ) + + @staticmethod + def _build_concatenated_lengths_tensor( + in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device + ) -> torch.Tensor: + """ + Build a tensor of concatenated subsequence lengths for each batch item. 
+ Args: + in_batch_seq_lens: A list of per-batch lists, where each inner list contains + the lengths of subsequences for that batch item. + max_seq_len: The maximum allowed sequence length (number of subsequences and + total length constraints are validated against this value). + device: The torch device on which to allocate the output tensor. + Returns: + A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths + for each batch item, padded with zeros beyond the number of subsequences. + Raises: + ValueError: If a batch item has more subsequences than max_seq_len or if the + sum of its subsequence lengths exceeds max_seq_len. + """ + batch_size = len(in_batch_seq_lens) + concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device) + for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens): + if len(doc_seq_lens) > max_seq_len: + raise ValueError( + f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) " + f"for batch index {batch_idx}." + ) + if sum(doc_seq_lens) > max_seq_len: + raise ValueError( + f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) " + f"for batch index {batch_idx}." + ) + if len(doc_seq_lens) > 0: + concatenated_lengths[batch_idx, : len(doc_seq_lens)] = torch.tensor( + doc_seq_lens, dtype=torch.int32, device=device + ) + return concatenated_lengths + + @staticmethod + def _get_unpad_data_for_concatenated_sequences( + attention_mask_in_length: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Compute unpadded indices and cumulative sequence lengths for concatenated sequences. + + This helper operates on a batched tensor of sequence indicators and produces: + a flattened index tensor of all valid (unpadded) positions, the cumulative + sequence lengths over those positions, and the maximum sequence length in + the batch. 
+
+        Args:
+            attention_mask_in_length (torch.Tensor): A 2D tensor of shape
+                (batch_size, max_seq_len). Each non-zero entry gives the length of one
+                sub-sequence packed into that batch row; zero entries are unused slots.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, int]:
+                - indices: 1D tensor of flattened indices for all valid (unpadded)
+                    tokens in the batch (indexing into a tensor of shape
+                    (batch_size * max_seq_len,)).
+                - cu_seqlens: 1D int32 tensor of cumulative sub-sequence lengths with a
+                    leading zero (shape: num_subsequences + 1), suitable for variable-
+                    length attention utilities.
+                - max_seqlen_in_batch: Length of the longest sub-sequence in the
+                    batch, as an int.
+
+        Raises:
+            ValueError: If no valid (non-zero) entries are present in the input
+                tensor (i.e., all positions are padded).
+        """
+
+        length = attention_mask_in_length.sum(dim=-1)
+        seqlen = attention_mask_in_length.size(-1)
+        attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(
+            len(length), seqlen
+        ) < length.unsqueeze(1)
+        seqlens_in_batch = attention_mask_in_length[attention_mask_in_length > 0]
+        if seqlens_in_batch.numel() == 0:
+            raise ValueError("No subsequence lengths provided for inter-document masking.")
+        indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
+        max_seqlen_in_batch = int(seqlens_in_batch.max().item())
+        cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+        return indices, cu_seqlens, max_seqlen_in_batch
+
+    @classmethod
+    def _execute_dao_flash_with_inter_document_masking(
+        cls,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        dropout: float,
+        attention_masking_information: tuple[torch.Tensor, torch.Tensor, int],
+    ) -> torch.Tensor:
+        if flash_attn_varlen_func is None:
+            raise NotImplementedError(
+                "ERROR! Dao Flash Attention varlen kernel is not available. " "Install flash-attn with varlen support."
+ ) + + indices, cu_seqlens, max_seqlen = attention_masking_information + + batch_size, seq_len, n_head_q, head_dim = q.shape + n_head_kv = k.shape[2] + + q_flat = q.reshape(batch_size * seq_len, n_head_q, head_dim) + k_flat = k.reshape(batch_size * seq_len, n_head_kv, head_dim) + v_flat = v.reshape(batch_size * seq_len, n_head_kv, head_dim) + + q_unpad = q_flat.index_select(0, indices) + k_unpad = k_flat.index_select(0, indices) + v_unpad = v_flat.index_select(0, indices) + + y_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + causal=True, + softmax_scale=None, + window_size=(-1, -1), + ) + + y = torch.zeros( + (batch_size * seq_len, n_head_q, head_dim), + dtype=y_unpad.dtype, + device=y_unpad.device, + ) + y.index_copy_(0, indices, y_unpad) + y = y.reshape(batch_size, seq_len, n_head_q, head_dim) + return y + def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies projections to the input tensor to get queries, keys, and values. @@ -600,6 +780,7 @@ def execute_attention( v: torch.Tensor, dropout: float, attention_impl: AttentionImplementation, + attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None, ) -> torch.Tensor: """ Executes attention mechanism based on the specified implementation. @@ -611,6 +792,8 @@ def execute_attention( v (torch.Tensor): The value tensor. dropout (float): The dropout rate. attention_impl (AttentionImplementation): The attention implementation to use. + attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None): + Optional tensor containing masking information for inter-document attention. Returns: torch.Tensor: The output tensor. 
@@ -624,7 +807,7 @@ def execute_attention( query=q, key=k, value=v, - attn_mask=None, + attn_mask=attention_masking_information, dropout_p=dropout, is_causal=True, ) # (B, nh_q, T, hd) @@ -646,23 +829,38 @@ def execute_attention( # Note, that the library is not required for the CPU-only tests. if flash_attn_func is None: raise NotImplementedError("ERROR! Dao Flash Attention is not installed.") - # the next three lines are only needed for flash-attn from Daio Lab + # the next three lines are only needed for flash-attn from Dao Lab q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd) k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd) v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd) - y = flash_attn_func( - q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1) - ) # (B, T, nh_q, hd) + if attention_masking_information is None: + y = flash_attn_func( + q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1) + ) # (B, T, nh_q, hd) + else: + y = cls._execute_dao_flash_with_inter_document_masking( + q=q, + k=k, + v=v, + dropout=dropout, + attention_masking_information=attention_masking_information, + ) else: raise NotImplementedError(f"Attention implementation {attention_impl} not supported") return y # (B, T, nh_q, hd) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None, + ) -> torch.Tensor: """ Forward pass of the CausalSelfAttention module. Args: x (torch.Tensor): Input tensor of shape (B, T, n_embd) + attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None): + Optional tensor containing masking information for inter-document attention. Returns: torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection. 
@@ -675,7 +873,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.q_norm is not None and self.k_norm is not None: q = self.q_norm(q) k = self.k_norm(k) - y = CausalSelfAttention.execute_attention(q, k, v, self.dropout, self.attention_impl) # (B, T, nh_q, hd) + y = CausalSelfAttention.execute_attention( + q, k, v, self.dropout, self.attention_impl, attention_masking_information + ) # (B, T, nh_q, hd) y = y.reshape(B, T, -1) # (B, T, n_embd), re-assemble all head outputs side by side return self.resid_dropout(self.c_proj(y)) # (B, T, n_embd), output projection @@ -798,17 +998,23 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None: f"but got `n_embd = {n_embd}` and `ffn_hidden = {ffn_hidden}`." ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None, + ) -> torch.Tensor: """ Forward pass of the GPT2Block. Args: x (torch.Tensor): Input tensor. + attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None): + Attention masking information. Returns: torch.Tensor: Output tensor. """ - x = x + self.attn(self.attention_norm(x)) + x = x + self.attn(self.attention_norm(x), attention_masking_information=attention_masking_information) x = x + self.mlp(self.ffn_norm(x)) return x @@ -839,6 +1045,7 @@ def __init__( use_weight_tying: bool, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, + sub_seq_lengths_key: str | None = None, ): """ Initializes the GPT2LLM object. @@ -867,6 +1074,8 @@ def __init__( enforce_swiglu_hidden_dim_multiple_of (int): Enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. + sub_seq_lengths_key (str, optional): The key for sub sequence lengths to be + used for inter document masking. 
""" weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], @@ -876,6 +1085,7 @@ def __init__( super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key + self.sub_seq_lengths_key = sub_seq_lengths_key self.sequence_length = sequence_length self.n_embd = n_embd self.n_layer = n_layer @@ -943,13 +1153,14 @@ def __init__( ) # https://paperswithcode.com/method/weight-tying @overload - def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor | list[list[int]]]) -> dict[str, torch.Tensor]: """ Forward pass of the GPT2LLM module. Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor | list[list[int]]]): A dictionary containing input tensors. - sample_key (str): Key for the input tensor containing token ids. + - sub_seq_lengths_key (str, optional): Key for the input tensor containing subsequence lengths. Returns: dict[str, torch.Tensor]: A dictionary containing output tensors. @@ -970,27 +1181,35 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ ... - def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: + def forward( + self, inputs: dict[str, torch.Tensor | list[list[int]]] | torch.Tensor + ) -> dict[str, torch.Tensor] | torch.Tensor: """ Forward pass of the GPT2LLM module. Args: - inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. + inputs (dict[str, torch.Tensor | list[list[int]]] | torch.Tensor): Input data. Returns: dict[str, torch.Tensor] | torch.Tensor: Model output. 
""" if isinstance(inputs, dict): - return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} + return { + self.prediction_key: self.forward_impl( + inputs[self.sample_key], sub_seq_lengths=inputs.get(self.sub_seq_lengths_key) + ) + } else: return self.forward_impl(inputs) - def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: + def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] | None = None) -> torch.Tensor: """ Forward pass implementation of the GPT2LLM module. Args: inputs (torch.Tensor): A tensor containing input token ids. + sub_seq_lengths (list[list[int]], optional): The lengths of the subsequences of each sequence + in the batch. To be used for inter document masking. Returns: torch.Tensor: A tensor containing output logits. @@ -1013,8 +1232,16 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: # TODO: use drop out also without absolute position embedding? h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h + # TODO: Handle this in case of pipeline parallelism. 
+ if sub_seq_lengths is not None: + attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking( + in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len + ) + else: + attention_masking_information = None + for layer_idx in self.transformer.h: - h = self.transformer.h[layer_idx](h) + h = self.transformer.h[layer_idx](h, attention_masking_information=attention_masking_information) h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h return h @@ -1047,19 +1274,37 @@ def manual_scaled_dot_product_attention( attn_bias = torch.zeros( L, S, dtype=query.dtype, device=query.device ) # device added (not part of the original code) + fully_masked = None + if attn_mask is not None and attn_mask.dim() == 3: + attn_bias = attn_bias.unsqueeze(0).repeat(attn_mask.size(0), 1, 1) if is_causal: - assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) # device added - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - - if attn_mask is not None: + if attn_mask is None: + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + elif attn_mask.dtype == torch.bool: + if attn_mask.dim() == 3: + combined_mask = temp_mask.unsqueeze(0) & attn_mask + else: + combined_mask = temp_mask & attn_mask + fully_masked = ~combined_mask.any(dim=-1) + attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf")) + else: + if attn_mask.dim() == 3: + temp_mask = temp_mask.unsqueeze(0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias += attn_mask + elif attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask + attn_bias = attn_bias.to(query.dtype) attn_weight = query @ key.transpose(-2, -1) * scale_factor + if attn_bias.dim() == 3: + 
attn_bias = attn_bias.unsqueeze(1) attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + if fully_masked is not None and attn_weight.dim() == 4: + attn_weight = attn_weight.masked_fill(fully_masked.unsqueeze(1).unsqueeze(-1), 0.0) return attn_weight @ value diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 142aef920..5ad233d81 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -617,6 +617,7 @@ def get_gpt2_model( use_meta_device: Optional[bool] = False, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, + sub_seq_lengths_key: str | None = None, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -640,6 +641,7 @@ def get_gpt2_model( seed=seed, use_weight_tying=use_weight_tying, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of, + sub_seq_lengths_key=sub_seq_lengths_key, ) if use_meta_device and use_weight_tying: raise ValueError( diff --git a/tests/models/test_causal_self_attention.py b/tests/models/test_causal_self_attention.py index 6da485462..d36162be5 100644 --- a/tests/models/test_causal_self_attention.py +++ b/tests/models/test_causal_self_attention.py @@ -1,43 +1,24 @@ """ -Note: test_attention_types_approximate_equality can print the output of different attention implementations. +Note: test_attention_types_approximate_equality can print the output of different attention implementations. 
To do so, turn on verbose and run 'pytest tests/models/test_causal_self_attention.py -s' """ + from copy import deepcopy import pytest import torch +import modalities.models.gpt2.gpt2_model as gpt2_model from modalities.models.gpt2.gpt2_model import ( AttentionConfig, CausalSelfAttention, LayerNorms, LayerNormWrapperConfig, PytorchRMSLayerNormConfig, + flash_attn_varlen_func, ) -torch.manual_seed(0) - - -def _get_random_input_seq(embedding_shape): - flash_attn_supported_dtype = torch.bfloat16 - return torch.rand(size=embedding_shape, dtype=flash_attn_supported_dtype).cuda() - - -def _get_random_attention_layer(n_head_q, n_head_kv, n_embd, attention_impl, attention_config): - self_attention_layer = CausalSelfAttention( - n_head_q=n_head_q, - n_head_kv=n_head_kv, - n_embd=n_embd, - bias=False, - dropout=0.0, - attention_config=attention_config, - attention_impl=attention_impl, - ).cuda() - self_attention_layer.q_attn = self_attention_layer.q_attn.bfloat16() - self_attention_layer.k_attn = self_attention_layer.k_attn.bfloat16() - self_attention_layer.v_attn = self_attention_layer.v_attn.bfloat16() - self_attention_layer.c_proj = self_attention_layer.c_proj.bfloat16() - return self_attention_layer +torch.manual_seed(0) # FIXME remove or do within tests? 
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") @@ -272,3 +253,581 @@ def test_qk_norm(n_head_q, n_head_kv, n_embd, attention_impl): assert output_no_norm.shape == output_with_norm.shape == embedding_shape assert not torch.allclose(output_no_norm, output_with_norm, atol=1e-6) + + +def test_inter_document_masking_manual_mask_shape_and_blocks(): + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="manual", + attention_config=attention_config, + ) + + mask = attention_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 1], [1, 2]], max_seq_len=3) + + expected_batch_0 = torch.tensor( + [ + [True, True, False], + [True, True, False], + [False, False, True], + ] + ) + expected_batch_1 = torch.tensor( + [ + [True, False, False], + [False, True, True], + [False, True, True], + ] + ) + + assert mask.shape == (2, 3, 3) + torch.testing.assert_close(mask[0].cpu(), expected_batch_0) + torch.testing.assert_close(mask[1].cpu(), expected_batch_1) + + +def test_inter_document_masking_manual_forward_allows_mask(): + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="manual", + attention_config=attention_config, + ) + + inputs = torch.rand(1, 5, 4) + mask = attention_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3]], max_seq_len=5) + + output_masked = attention_layer(inputs, attention_masking_information=mask) + output_doc_1 = attention_layer(inputs[:, :2, :]) + output_doc_2 = attention_layer(inputs[:, 2:, :]) + output_reference = torch.cat([output_doc_1, output_doc_2], dim=1) + + torch.testing.assert_close(output_masked, output_reference) + + +def test_inter_document_masking_manual_mask_symmetry_and_blocks(): + """ + Test to ensure that the inter-document masking is symmetric and correctly blocks 
attention + between different documents within a batch. + """ + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="manual", + attention_config=attention_config, + ) + + in_batch_seq_lens = [[2, 1, 3], [1, 2, 1]] + mask = attention_layer.prepare_inter_document_masking(in_batch_seq_lens=in_batch_seq_lens, max_seq_len=6) + + assert mask.shape == (2, 6, 6) + for batch_index, doc_seq_lens in enumerate(in_batch_seq_lens): + expected = torch.zeros((6, 6), dtype=torch.bool) + cursor = 0 + for length in doc_seq_lens: + expected[cursor : cursor + length, cursor : cursor + length] = True + cursor += length + torch.testing.assert_close(mask[batch_index].cpu(), expected) + torch.testing.assert_close(mask[batch_index].cpu(), mask[batch_index].cpu().transpose(0, 1)) + + +def test_inter_document_masking_manual_handles_empty_docs(): + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="manual", + attention_config=attention_config, + ) + + mask = attention_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 1], []], max_seq_len=3) + + assert mask.shape == (2, 3, 3) + torch.testing.assert_close(mask[1].cpu(), torch.zeros((3, 3), dtype=torch.bool)) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +def test_inter_document_masking_device_and_dtype_propagation(): + attention_config = AttentionConfig(qkv_transforms=[]) + manual_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="manual", + attention_config=attention_config, + ).cuda() + + manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 1]], max_seq_len=3) + + assert manual_mask.device == manual_layer.c_proj.weight.device + assert manual_mask.dtype == torch.bool + + dao_layer = 
_get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="dao_flash", + attention_config=attention_config, + ).cuda() + + indices, cu_seqlens, max_seqlen = dao_layer.prepare_inter_document_masking( + in_batch_seq_lens=[[2, 1]], max_seq_len=3 + ) + + assert indices.device == dao_layer.c_proj.weight.device + assert cu_seqlens.device == dao_layer.c_proj.weight.device + assert indices.dtype == torch.int64 + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen == 2 + + +def test_inter_document_masking_manual_float_mask_matches_bool(): + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="manual", + attention_config=attention_config, + ) + + inputs = torch.rand(1, 3, 4) + bool_mask = torch.tensor( + [ + [ + [True, False, False], + [True, True, False], + [False, True, True], + ] + ] + ) + float_mask = torch.where(bool_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).to(inputs.dtype) + + output_bool = attention_layer(inputs, attention_masking_information=bool_mask) + output_float = attention_layer(inputs, attention_masking_information=float_mask) + + torch.testing.assert_close(output_bool, output_float) + + +def test_inter_document_masking_dao_flash_empty_cases(): + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="dao_flash", + attention_config=attention_config, + ) + + indices, cu_seqlens, max_seqlen = attention_layer.prepare_inter_document_masking( + in_batch_seq_lens=[[2, 1], []], max_seq_len=3 + ) + assert indices.numel() == 3 + assert cu_seqlens.tolist() == [0, 2, 3] + assert max_seqlen == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def 
test_inter_document_masking_dao_flash_empty_docs_forward(): + attention_config = AttentionConfig(qkv_transforms=[]) + dao_layer = _get_random_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="dao_flash", + attention_config=attention_config, + ) + + inputs = _get_random_input_seq((2, 3, 4)) + masking = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 1], []], max_seq_len=3) + output = dao_layer(inputs, attention_masking_information=masking) + + torch.testing.assert_close(output[1], torch.zeros_like(output[1])) + + +def test_inter_document_masking_dao_flash_validation_errors(): + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="dao_flash", + attention_config=attention_config, + ) + + with pytest.raises(ValueError): + attention_layer.prepare_inter_document_masking(in_batch_seq_lens=[[1, 1, 1, 1]], max_seq_len=3) + + with pytest.raises(ValueError): + attention_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 2]], max_seq_len=3) + + with pytest.raises(ValueError): + attention_layer.prepare_inter_document_masking(in_batch_seq_lens=[[], []], max_seq_len=3) + + +def test_inter_document_masking_pytorch_flash_not_supported(): + attention_config = AttentionConfig(qkv_transforms=[]) + attention_layer = _get_identity_attention_layer( + n_head_q=2, + n_head_kv=2, + n_embd=4, + attention_impl="pytorch_flash", + attention_config=attention_config, + ) + + with pytest.raises(NotImplementedError): + attention_layer.prepare_inter_document_masking(in_batch_seq_lens=[[1, 1]], max_seq_len=2) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +@pytest.mark.parametrize( + "masked_attn_type, docwise_attn_type", + [ + ("dao_flash", "manual"), + ("dao_flash", "dao_flash"), + 
("manual", "manual"), + ("manual", "dao_flash"), + ], +) +def test_inter_document_masking_matches_docwise_attention(masked_attn_type, docwise_attn_type): + torch.manual_seed(0) + masked_layer, docwise_layer = _build_matching_attention_layers( + masked_attn_type=masked_attn_type, docwise_attn_type=docwise_attn_type + ) + + inputs = _get_random_input_seq((1, 5, 16)) + mask = masked_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3]], max_seq_len=5) + + output_masked = masked_layer(inputs, attention_masking_information=mask) + output_doc_1 = docwise_layer(inputs[:, :2, :]) + output_doc_2 = docwise_layer(inputs[:, 2:, :]) + output_reference = torch.cat([output_doc_1, output_doc_2], dim=1) + + torch.testing.assert_close(output_masked, output_reference, atol=2.5e-3, rtol=0.016) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +@pytest.mark.parametrize( + "masked_attn_type, docwise_attn_type", + [ + ("dao_flash", "manual"), + ("dao_flash", "dao_flash"), + ("manual", "manual"), + ("manual", "dao_flash"), + ], +) +def test_inter_document_masking_matches_docwise_attention_gqa(masked_attn_type, docwise_attn_type): + torch.manual_seed(0) + masked_layer, docwise_layer = _build_matching_attention_layers( + masked_attn_type=masked_attn_type, docwise_attn_type=docwise_attn_type, n_head_kv=2 + ) + + inputs = _get_random_input_seq((1, 6, 16)) + mask = masked_layer.prepare_inter_document_masking(in_batch_seq_lens=[[1, 2, 3]], max_seq_len=6) + + output_masked = masked_layer(inputs, attention_masking_information=mask) + output_doc_1 = docwise_layer(inputs[:, :1, :]) + output_doc_2 = docwise_layer(inputs[:, 1:3, :]) + output_doc_3 = docwise_layer(inputs[:, 3:, :]) + output_reference = torch.cat([output_doc_1, output_doc_2, output_doc_3], dim=1) + + torch.testing.assert_close(output_masked, output_reference, atol=2.5e-3, 
rtol=0.016) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_manual_matches_dao_flash_with_masks(): + torch.manual_seed(0) + dao_layer, manual_layer = _build_matching_dao_and_manual_attention() + + inputs = _get_random_input_seq((2, 5, 16)) + dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3], [1, 1, 2]], max_seq_len=5) + manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3], [1, 1, 2]], max_seq_len=5) + + output_dao = dao_layer(inputs, attention_masking_information=dao_mask) + output_manual = manual_layer(inputs, attention_masking_information=manual_mask) + + torch.testing.assert_close(output_dao, output_manual, atol=2.5e-3, rtol=0.016) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_dao_flash_blocks_cross_doc_leakage(): + torch.manual_seed(0) + dao_layer, _ = _build_matching_dao_and_manual_attention() + + inputs = torch.zeros((1, 6, 16), dtype=torch.bfloat16, device="cuda") + inputs[:, :2, :] = 1000.0 + + mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 4]], max_seq_len=6) + output_masked = dao_layer(inputs, attention_masking_information=mask) + output_unmasked = dao_layer(inputs) + + assert not torch.allclose(output_masked[:, 2:, :], output_unmasked[:, 2:, :], atol=1e-3, rtol=1e-3) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_dao_flash_matches_manual_with_batchwise_splits(): + torch.manual_seed(0) + dao_layer, manual_layer = 
_build_matching_dao_and_manual_attention() + + inputs = _get_random_input_seq((2, 5, 16)) + sub_seq_lengths = [[2, 3], [1, 1, 3]] + dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=5) + manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=5) + + output_dao = dao_layer(inputs, attention_masking_information=dao_mask) + output_manual = manual_layer(inputs, attention_masking_information=manual_mask) + + torch.testing.assert_close(output_dao, output_manual, atol=2.5e-3, rtol=0.016) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_with_padding_matches_manual_on_valid_tokens(): + torch.manual_seed(0) + dao_layer, manual_layer = _build_matching_dao_and_manual_attention() + + inputs = _get_random_input_seq((2, 6, 16)) + sub_seq_lengths = [[2, 2], [1, 1, 2]] + valid_lengths = _sum_lengths_per_batch(sub_seq_lengths) + + dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=6) + manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=6) + + output_dao = dao_layer(inputs, attention_masking_information=dao_mask) + output_manual = manual_layer(inputs, attention_masking_information=manual_mask) + + for batch_index, valid_len in enumerate(valid_lengths): + torch.testing.assert_close( + output_dao[batch_index, :valid_len, :], + output_manual[batch_index, :valid_len, :], + atol=2.5e-3, + rtol=0.016, + ) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_with_padding_zeroes_dao_padded_outputs(): + torch.manual_seed(0) + dao_layer, 
_ = _build_matching_dao_and_manual_attention() + + inputs = _get_random_input_seq((2, 6, 16)) + sub_seq_lengths = [[2, 1], [1, 1, 1]] + valid_lengths = _sum_lengths_per_batch(sub_seq_lengths) + + dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=6) + output_dao = dao_layer(inputs, attention_masking_information=dao_mask) + + for batch_index, valid_len in enumerate(valid_lengths): + if valid_len < inputs.size(1): + torch.testing.assert_close( + output_dao[batch_index, valid_len:, :], + torch.zeros_like(output_dao[batch_index, valid_len:, :]), + ) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_dao_flash_padded_gradients_are_zero(): + """ + Test to ensure that the gradients of padded tokens are zero when using DAO flash attention. + This is tested by backpropagating through the padded outputs and checking the gradients. 
+ """ + torch.manual_seed(0) + dao_layer, _ = _build_matching_dao_and_manual_attention() + + inputs = _get_random_input_seq((2, 5, 16)).requires_grad_(True) + sub_seq_lengths = [[2, 1], [1, 1]] + valid_lengths = _sum_lengths_per_batch(sub_seq_lengths) + + dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=5) + output_dao = dao_layer(inputs, attention_masking_information=dao_mask) + + loss = 0.0 + for batch_index, valid_len in enumerate(valid_lengths): + if valid_len < inputs.size(1): + loss = loss + output_dao[batch_index, valid_len:, :].sum() + + loss.backward() + torch.testing.assert_close(inputs.grad, torch.zeros_like(inputs.grad)) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_dao_flash_handles_single_token_docs(): + torch.manual_seed(0) + dao_layer, manual_layer = _build_matching_dao_and_manual_attention() + + inputs = _get_random_input_seq((1, 5, 16)) + sub_seq_lengths = [[1, 1, 3]] + dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=5) + manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=5) + + output_dao = dao_layer(inputs, attention_masking_information=dao_mask) + output_manual = manual_layer(inputs, attention_masking_information=manual_mask) + + torch.testing.assert_close(output_dao, output_manual, atol=2.5e-3, rtol=0.016) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +@pytest.mark.skipif(flash_attn_varlen_func is None, reason="This test requires flash-attn varlen support.") +def test_inter_document_masking_dao_flash_randomized_splits(): + torch.manual_seed(0) + dao_layer, manual_layer = _build_matching_dao_and_manual_attention() + generator = torch.Generator().manual_seed(123) 
+ + inputs = _get_random_input_seq((1, 6, 16)) + for _ in range(3): + sub_seq_lengths = [_generate_sub_seq_lengths(total_len=6, max_chunk=3, generator=generator)] + dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=6) + manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=sub_seq_lengths, max_seq_len=6) + + output_dao = dao_layer(inputs, attention_masking_information=dao_mask) + output_manual = manual_layer(inputs, attention_masking_information=manual_mask) + + torch.testing.assert_close(output_dao, output_manual, atol=2.5e-3, rtol=0.016) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.") +def test_inter_document_masking_dao_flash_passes_expected_unpad_data(monkeypatch): + attention_config = AttentionConfig(qkv_transforms=[]) + dao_layer = _get_random_attention_layer( + n_head_q=4, + n_head_kv=4, + n_embd=16, + attention_impl="dao_flash", + attention_config=attention_config, + ) + inputs = _get_random_input_seq((1, 5, 16)) + expected_masking = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3]], max_seq_len=5) + expected_indices, expected_cu_seqlens, expected_max_seqlen = expected_masking + + captured = {} + + def fake_flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal, + softmax_scale, + window_size, + ): + captured["unpad_len"] = q_unpad.shape[0] + captured["cu_seqlens_q"] = cu_seqlens_q.detach().cpu() + captured["cu_seqlens_k"] = cu_seqlens_k.detach().cpu() + captured["max_seqlen_q"] = max_seqlen_q + captured["max_seqlen_k"] = max_seqlen_k + return torch.zeros_like(q_unpad) + + monkeypatch.setattr(gpt2_model, "flash_attn_func", object()) + monkeypatch.setattr(gpt2_model, "flash_attn_varlen_func", fake_flash_attn_varlen_func) + + output = dao_layer(inputs, attention_masking_information=expected_masking) + + assert output.shape == inputs.shape + 
assert captured["unpad_len"] == expected_indices.numel() + assert captured["cu_seqlens_q"].tolist() == expected_cu_seqlens.detach().cpu().tolist() + assert captured["cu_seqlens_k"].tolist() == expected_cu_seqlens.detach().cpu().tolist() + assert captured["max_seqlen_q"] == expected_max_seqlen + assert captured["max_seqlen_k"] == expected_max_seqlen + + +def _build_matching_dao_and_manual_attention(n_head_kv: int = 4): + return _build_matching_attention_layers( + masked_attn_type="dao_flash", docwise_attn_type="manual", n_head_kv=n_head_kv + ) + + +def _build_matching_attention_layers( + masked_attn_type: str, docwise_attn_type: str, n_head_kv: int = 4 +) -> tuple[CausalSelfAttention, CausalSelfAttention]: + attention_config = AttentionConfig(qkv_transforms=[]) + + masked_layer = _get_random_attention_layer( + n_head_q=4, + n_head_kv=n_head_kv, + n_embd=16, + attention_impl=masked_attn_type, + attention_config=attention_config, + ) + docwise_layer = _get_random_attention_layer( + n_head_q=4, + n_head_kv=n_head_kv, + n_embd=16, + attention_impl=docwise_attn_type, + attention_config=attention_config, + ) + docwise_layer.load_state_dict(masked_layer.state_dict()) + return masked_layer, docwise_layer + + +def _get_random_input_seq(embedding_shape): + flash_attn_supported_dtype = torch.bfloat16 + return torch.rand(size=embedding_shape, dtype=flash_attn_supported_dtype).cuda() + + +def _get_random_attention_layer(n_head_q, n_head_kv, n_embd, attention_impl, attention_config): + self_attention_layer = CausalSelfAttention( + n_head_q=n_head_q, + n_head_kv=n_head_kv, + n_embd=n_embd, + bias=False, + dropout=0.0, + attention_config=attention_config, + attention_impl=attention_impl, + ).cuda() + self_attention_layer.q_attn = self_attention_layer.q_attn.bfloat16() + self_attention_layer.k_attn = self_attention_layer.k_attn.bfloat16() + self_attention_layer.v_attn = self_attention_layer.v_attn.bfloat16() + self_attention_layer.c_proj = self_attention_layer.c_proj.bfloat16() + 
return self_attention_layer + + +def _get_identity_attention_layer(n_head_q, n_head_kv, n_embd, attention_impl, attention_config): + self_attention_layer = CausalSelfAttention( + n_head_q=n_head_q, + n_head_kv=n_head_kv, + n_embd=n_embd, + bias=False, + dropout=0.0, + attention_config=attention_config, + attention_impl=attention_impl, + ) + with torch.no_grad(): + eye = torch.eye(n_embd, dtype=self_attention_layer.q_attn.weight.dtype) + self_attention_layer.q_attn.weight.copy_(eye) + self_attention_layer.k_attn.weight.copy_(eye) + self_attention_layer.v_attn.weight.copy_(eye) + self_attention_layer.c_proj.weight.copy_(eye) + return self_attention_layer + + +def _generate_sub_seq_lengths(total_len: int, max_chunk: int, generator: torch.Generator) -> list[int]: + lengths = [] + remaining = total_len + while remaining > 0: + next_len = int(torch.randint(1, min(max_chunk, remaining) + 1, (1,), generator=generator).item()) + lengths.append(next_len) + remaining -= next_len + return lengths + + +def _sum_lengths_per_batch(sub_seq_lengths: list[list[int]]) -> list[int]: + return [sum(lengths) for lengths in sub_seq_lengths] diff --git a/tests/models/test_gpt2_collator.py b/tests/models/test_gpt2_collator.py new file mode 100644 index 000000000..a6e1226fa --- /dev/null +++ b/tests/models/test_gpt2_collator.py @@ -0,0 +1,98 @@ +import pytest +import torch + +from modalities.models.gpt2.collator import GPT2LLMCollateFn + + +def test_gpt2_collate_shifts_samples_and_targets(): + collator = GPT2LLMCollateFn(sample_key="input_ids", target_key="labels") + batch = [ + {"input_ids": torch.tensor([1, 2, 3, 4])}, + {"input_ids": torch.tensor([5, 6, 7, 8])}, + ] + + result = collator(batch) + + assert result.samples["input_ids"].tolist() == [[1, 2, 3], [5, 6, 7]] + assert result.targets["labels"].tolist() == [[2, 3, 4], [6, 7, 8]] + + +def test_gpt2_collate_sub_seq_lengths_without_eos(): + collator = GPT2LLMCollateFn( + sample_key="input_ids", + target_key="labels", + 
sub_seq_lengths_key="sub_seq_lengths", + eos_token_id=99, + ) + batch = [ + {"input_ids": torch.tensor([10, 11, 12, 13, 14])}, + {"input_ids": torch.tensor([20, 21, 22, 23, 24])}, + ] + + result = collator(batch) + + assert result.samples["sub_seq_lengths"] == [[4], [4]] + + +def test_gpt2_collate_sub_seq_lengths_with_eos(): + collator = GPT2LLMCollateFn( + sample_key="input_ids", + target_key="labels", + sub_seq_lengths_key="sub_seq_lengths", + eos_token_id=99, + ) + batch = [ + {"input_ids": torch.tensor([1, 99, 2, 3, 99])}, + {"input_ids": torch.tensor([7, 8, 9, 99, 10])}, + ] + + result = collator(batch) + + assert result.samples["sub_seq_lengths"] == [[2, 2], [4]] + + +def test_gpt2_collate_sub_seq_lengths_with_eos_and_padding(): + collator = GPT2LLMCollateFn( + sample_key="input_ids", + target_key="labels", + sub_seq_lengths_key="sub_seq_lengths", + eos_token_id=99, + padding_token_id=0, + ) + batch = [ + {"input_ids": torch.tensor([1, 99, 2, 3, 4, 5])}, + {"input_ids": torch.tensor([7, 8, 99, 0, 0, 0])}, + ] + + result = collator(batch) + + assert result.samples["sub_seq_lengths"] == [[2, 3], [3]] + + +def test_gpt2_collate_sub_seq_lengths_adds_tail_when_not_padding(): + collator = GPT2LLMCollateFn( + sample_key="input_ids", + target_key="labels", + sub_seq_lengths_key="sub_seq_lengths", + eos_token_id=5, + padding_token_id=0, + ) + batch = [{"input_ids": torch.tensor([1, 5, 9, 8])}] + + result = collator(batch) + + assert result.samples["sub_seq_lengths"] == [[2, 1]] + + +def test_gpt2_collate_raises_when_sequence_starts_with_padding_and_no_eos(): + collator = GPT2LLMCollateFn( + sample_key="input_ids", + target_key="labels", + sub_seq_lengths_key="sub_seq_lengths", + eos_token_id=99, + padding_token_id=0, + ) + batch = [{"input_ids": torch.tensor([0, 1, 2, 3])}] + + with pytest.raises(AssertionError, match="Sequence starts with padding token"): + collator(batch)