Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/modalities/config/instantiation_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from pathlib import Path
from typing import Annotated, Any, Optional
Expand Down Expand Up @@ -27,6 +28,8 @@
from modalities.util import warn_rank_0
from modalities.utils.profilers.profilers import SteppableNoProfiler

logger = logging.getLogger(__name__)


class CudaEnvSettings(BaseModel):
local_rank: Annotated[int, Field(strict=True, ge=0)]
Expand All @@ -46,6 +49,7 @@ class ConsistencyEnforcement(BaseModel):
enforce_last_step_logged: bool = True
enforce_last_step_evaluated: bool = True
enforce_last_step_checkpointed: bool = True
enforce_enough_tokens_in_dataset: bool = True


class Intervals(BaseModel):
Expand Down Expand Up @@ -192,15 +196,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel

@model_validator(mode="after")
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel":
    """Verify that the training dataset holds enough tokens to reach the training target.

    The available token count is estimated as ``len(train_dataset) * sequence_length``
    and compared against ``training_target.num_target_tokens``. On a shortfall the
    behavior depends on ``consistency_enforcement.enforce_enough_tokens_in_dataset``:
    enforcement raises, otherwise the problem is only logged as a warning.

    Returns:
        The validated model instance (pydantic "after"-validator contract).

    Raises:
        ValueError: If the dataset is too small and enforcement is enabled.
    """
    dataset_tokens = len(self.train_dataset) * self.settings.step_profile.sequence_length
    expected_tokens = self.settings.training_target.num_target_tokens
    if dataset_tokens < expected_tokens:
        msg = f"Not enough tokens in dataset. Actual: {dataset_tokens}, Expected: >={expected_tokens}"
        if self.settings.consistency_enforcement.enforce_enough_tokens_in_dataset:
            raise ValueError(msg)
        else:
            # Best-effort mode: surface the shortfall without aborting startup.
            logger.warning(msg)
    return self


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import math
import os
import pickle
from itertools import repeat
from pathlib import Path
from typing import BinaryIO

Expand Down Expand Up @@ -82,30 +81,56 @@ def _write_index_segment(file_descriptor: BinaryIO, index_list: list[tuple[int,
def _write_data_segment(
file_descriptor: BinaryIO, token_data: list[np.ndarray], token_size_in_bytes: int, write_batch_size: int
) -> list[tuple[int, int]]:
def encoded_token_to_bytes(encoded_token: int, token_size_in_bytes: int) -> bytes:
# Converts an token_ids to its byte representation.
try:
token_bytes = encoded_token.to_bytes(token_size_in_bytes, byteorder="little", signed=False)
except OverflowError as e:
raise ValueError(f"Token {encoded_token} cannot be represented by {token_size_in_bytes} bytes.") from e
return token_bytes

samples = []
index_list = []
# Fast path: vectorized cast + tobytes (no per-token Python work).
# Preserves little-endian unsigned representation and overflow checks.

if token_size_in_bytes == 1:
dtype = np.dtype("u1")
elif token_size_in_bytes == 2:
dtype = np.dtype("<u2") # force little-endian
elif token_size_in_bytes == 4:
dtype = np.dtype("<u4") # force little-endian
else:
raise ValueError("Currently only support token byte sizes of 1, 2, and 4.")

max_allowed = (1 << (8 * token_size_in_bytes)) - 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the more readable version would be fast enough here. If we really need a fast one just hardcode the 3 possible values in above if-else-clause. :D

Suggested change
max_allowed = (1 << (8 * token_size_in_bytes)) - 1
max_allowed = 2 ** (8 * token_size_in_bytes) - 1


samples: list[bytes] = []
index_list: list[tuple[int, int]] = []
curr_offset = 0
pending = 0

for sample_tokens in token_data:
# convert token_ids to byte representation
sample_token_byte_string = b"".join(
map(encoded_token_to_bytes, sample_tokens.tolist(), repeat(token_size_in_bytes))
)
arr = np.asarray(sample_tokens)

# ---- Overflow / range check (preserves original semantics) ----
if arr.size:
min_val = int(arr.min())
max_val = int(arr.max())
if min_val < 0 or max_val > max_allowed:
raise ValueError(
f"Token values out of range for {token_size_in_bytes} bytes: "
f"min={min_val}, max={max_val}, allowed=[0, {max_allowed}]"
)
Comment on lines +111 to +114
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would be helpful to identify the faulty token (as in the previous implementation) or even better the index in token_data (via enumerate) and the index in arr (via argmax/argmin) here.

# ----------------------------------------------------------------

# Cast to correct unsigned little-endian dtype
arr = np.asarray(arr, dtype=dtype, order="C")
sample_token_byte_string = arr.tobytes(order="C")

samples.append(sample_token_byte_string)
index_list.append((curr_offset, len(sample_token_byte_string)))
curr_offset += len(sample_token_byte_string)
if len(samples) % write_batch_size == 0:

pending += 1
if pending >= write_batch_size:
file_descriptor.write(b"".join(samples))
samples = []
samples.clear()
pending = 0

if len(samples) > 0:
file_descriptor.write(b"".join(samples))

return index_list

@staticmethod
Expand Down