Skip to content
Open
15 changes: 15 additions & 0 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,21 @@ class BatchSamplerConfig(BaseModel):
class GPT2LLMCollateFnConfig(BaseModel):
    """Configuration for the GPT2 language-model collate function.

    Attributes:
        sample_key: Key under which the input token ids are stored.
        target_key: Key under which the shifted target token ids are stored.
        sub_seq_lengths_key: Optional key under which per-sequence sub-sequence
            lengths are stored. Must be set together with ``eos_token_id``.
        eos_token_id: Optional end-of-sequence token id used to delimit
            sub-sequences. Must be set together with ``sub_seq_lengths_key``.
        padding_token_id: Optional padding token id; only meaningful when
            sub-sequence lengths are computed.
    """

    sample_key: str
    target_key: str
    sub_seq_lengths_key: str | None = None
    eos_token_id: int | None = None
    padding_token_id: int | None = None

    @model_validator(mode="after")
    def check_sub_seq_lengths_and_eos_token(self) -> "GPT2LLMCollateFnConfig":
        # The two options only make sense as a pair: the key names where the
        # lengths go, and the eos id is what the lengths are derived from.
        has_lengths_key = self.sub_seq_lengths_key is not None
        has_eos_id = self.eos_token_id is not None
        if has_lengths_key != has_eos_id:
            raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.")
        return self

    @model_validator(mode="after")
    def check_padding_token_and_sub_seq_lengths(self) -> "GPT2LLMCollateFnConfig":
        # A padding token id is only consulted during sub-sequence length
        # computation, so it requires sub_seq_lengths_key to be set.
        if self.padding_token_id is not None and self.sub_seq_lengths_key is None:
            raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
        return self


class LLMDataLoaderConfig(BaseModel):
Expand Down
56 changes: 55 additions & 1 deletion src/modalities/models/gpt2/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,31 @@
class GPT2LLMCollateFn(CollateFnIF):
"""GPT2LLMCollateFn class to define a collate function for GPT2 language model."""

def __init__(
    self,
    sample_key: str,
    target_key: str,
    sub_seq_lengths_key: str | None = None,
    eos_token_id: int | None = None,
    padding_token_id: int | None = None,
):
    """
    Initializes the Collator object.
    When both eos_token_id and sub_seq_lengths_key are given, the collate
    function additionally produces a list[list[int]] of sub-sequence lengths
    (one list per sequence in the batch).

    Args:
        sample_key (str): The key for accessing the sample data.
        target_key (str): The key for accessing the target data.
        sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths.
        eos_token_id (int | None): The end-of-sequence token ID.
        padding_token_id (int | None): The padding token ID.
    """
    # Keys into the batch dictionaries.
    self.sample_key = sample_key
    self.target_key = target_key
    self.sub_seq_lengths_key = sub_seq_lengths_key
    # Special token ids used for sub-sequence length computation.
    self.eos_token_id = eos_token_id
    self.padding_token_id = padding_token_id

def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
    """
    Collates a list of samples into a DatasetBatch for next-token prediction.

    The token sequences of all samples are stacked; model inputs drop the
    last token and targets drop the first token. When configured, the
    per-sequence sub-sequence lengths are attached to the samples as well.

    Args:
        batch (list[dict[str, torch.Tensor]]): The list of sample dictionaries to collate.

    Returns:
        DatasetBatch: The batch holding the collated samples and targets.
    """
    stacked = torch.stack([torch.tensor(item[self.sample_key]) for item in batch])
    # Shift by one token: inputs exclude the last token, targets exclude the first.
    samples = {self.sample_key: stacked[:, :-1]}
    targets = {self.target_key: stacked[:, 1:]}
    if self.sub_seq_lengths_key is not None:
        # Sub-sequence lengths are derived from the (already shifted) input
        # sequences by locating the eos tokens in each of them.
        samples[self.sub_seq_lengths_key] = self._compute_sub_sequence_lengths_for_each_sequence(
            samples[self.sample_key]
        )
    return DatasetBatch(targets=targets, samples=samples)

def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]:
    """
    Computes the lengths of the eos-delimited sub-sequences of every sequence in the batch.

    Args:
        sample_tensor (torch.Tensor): 2D tensor of token ids, one row per sequence.

    Returns:
        list[list[int]]: One list of sub-sequence lengths per sequence.

    Raises:
        ValueError: If a sequence containing no eos token starts with the padding
            token, since no meaningful sub-sequence length can be derived then.
    """
    sub_seq_lengths = []
    for seq in sample_tensor:
        eos_positions = (seq == self.eos_token_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) == 0:
            # A sequence without any eos token is treated as a single (cut-off)
            # sub-sequence spanning the whole row. If it starts with padding we
            # cannot derive a valid length, so fail loudly instead of silently
            # producing a wrong one. (Was an `assert`, which is stripped under
            # `python -O`; raise explicitly with an actionable message instead.)
            if self.padding_token_id is not None and seq[0] == self.padding_token_id:
                raise ValueError(
                    "Invalid sequence: cannot start with a padding token when no eos token "
                    "is present, as this prevents valid sub-sequence length computation. "
                    "Ensure padding is only applied at the end of sequences after eos tokens."
                )
            sub_seq_lengths.append([len(seq)])
        else:
            subseq_lengths = self._compute_subsequence_length(seq, eos_positions)
            sub_seq_lengths.append(subseq_lengths)
    return sub_seq_lengths

def _compute_subsequence_length(self, seq: torch.Tensor, eos_positions: torch.Tensor) -> list[int]:
    """
    Computes the lengths of the eos-delimited sub-sequences of a single sequence.

    If the last sub-sequence is cut, i.e. does not end on an eos token, it is
    also included, unless the padding token is set and the trailing tokens are
    just padding.

    Args:
        seq (torch.Tensor): 1D tensor of token ids.
        eos_positions (torch.Tensor): 1D tensor of indices of eos tokens in seq.

    Returns:
        list[int]: The length of every sub-sequence.
    """
    # Work on plain Python ints from here on. This fixes a device-mismatch bug:
    # the previous torch.cat with a fresh torch.tensor(...) created a CPU tensor,
    # which fails when eos_positions lives on a GPU. It also avoids per-iteration
    # .item() calls in the loop below.
    end_positions = eos_positions.tolist()
    if self._has_cutoff_final_sequence(seq, end_positions[-1]):
        # The cut-off final sub-sequence ends at the last token of seq.
        end_positions.append(len(seq) - 1)
    # Compute length of each sub-sequence from consecutive end positions.
    subseq_lengths = []
    prev_pos = 0
    for pos in end_positions:
        subseq_lengths.append(pos - prev_pos + 1)
        prev_pos = pos + 1
    return subseq_lengths

def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eos_pos: int) -> bool:
    """
    Returns whether seq ends in a sub-sequence that was cut off before its eos token.

    Assumption: if the first token after the last eos is padding, so is the
    remainder, i.e. there is no cut-off final sub-sequence.
    """
    # Nothing follows the last eos token -> no cut-off tail.
    if last_eos_pos >= len(seq) - 1:
        return False
    # No padding configured -> any trailing tokens form a cut-off sub-sequence.
    if self.padding_token_id is None:
        return True
    # Trailing tokens count only if they are not padding.
    return seq[last_eos_pos + 1] != self.padding_token_id
Loading