32 changes: 21 additions & 11 deletions bsmetadata/metadata_utils.py
@@ -48,7 +48,7 @@ class BasicMetadata:


 def add_metadata_and_chunk_examples(
-    examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: MetadataConfig
+    examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: MetadataConfig, without_metadata: bool = False
 ) -> Dict[str, List]:
     """Adds metadata to the provided input examples, encodes them and groups them in chunks of size `cfg.max_seq_len`.

@@ -124,16 +124,26 @@ def is_metadata(idx: int) -> bool:
         for text_chunk_encoded, chunk_metadata_mask in chunks(
             max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
         ):
-            total_len = prefix_len + len(text_chunk_encoded)
-            padding_len = max_text_len - len(text_chunk_encoded)
-
-            input_ids = metadata_prefix_encoded + text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
-            attention_mask = [1] * total_len + [0] * padding_len
-            metadata_mask = [1] * prefix_len + [int(x) for x in chunk_metadata_mask] + [0] * padding_len
-
-            linearized_examples["input_ids"].append(input_ids)
-            linearized_examples["attention_mask"].append(attention_mask)
-            linearized_examples["metadata_mask"].append(metadata_mask)
+            if without_metadata:
+                total_len = len(text_chunk_encoded)
+                padding_len = cfg.max_seq_len - len(text_chunk_encoded)
+                attention_mask = [1] * total_len + [0] * padding_len
+                input_ids = text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
+                # pad metadata_mask to cfg.max_seq_len so it matches input_ids and attention_mask
+                metadata_mask = [0] * total_len + [0] * padding_len
linearized_examples["input_ids"].append(input_ids)
linearized_examples["attention_mask"].append(attention_mask)
linearized_examples["metadata_mask"].append(metadata_mask)
else:
total_len = prefix_len + len(text_chunk_encoded)
padding_len = max_text_len - len(text_chunk_encoded)

input_ids = metadata_prefix_encoded + text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
attention_mask = [1] * total_len + [0] * padding_len
metadata_mask = [1] * prefix_len + [int(x) for x in chunk_metadata_mask] + [0] * padding_len

linearized_examples["input_ids"].append(input_ids)
linearized_examples["attention_mask"].append(attention_mask)
linearized_examples["metadata_mask"].append(metadata_mask)

     return linearized_examples
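
The new without_metadata path emits the same text chunks without the prepended metadata prefix, padded to cfg.max_seq_len, with an all-zero metadata_mask. A minimal sketch of how the flag might be wired into preprocessing is below; the GPT-2 tokenizer, the MetadataConfig construction, the import paths, and the dataset file are assumptions for illustration, not details taken from this PR.

from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer

from bsmetadata.metadata_utils import MetadataConfig, add_metadata_and_chunk_examples  # import path assumed

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed fast tokenizer
cfg = MetadataConfig(max_seq_len=1024)             # assumed constructor arguments

dataset = load_dataset("json", data_files="train.jsonl", split="train")  # hypothetical data file

# Chunks with the metadata prefix prepended (existing behaviour).
with_metadata = dataset.map(
    partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=cfg),
    batched=True,
    remove_columns=dataset.column_names,
)

# Plain-text chunks produced by the new without_metadata flag.
no_metadata = dataset.map(
    partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=cfg, without_metadata=True),
    batched=True,
    remove_columns=dataset.column_names,
)

Both calls return input_ids, attention_mask and metadata_mask columns; with without_metadata=True no prefix tokens are prepended and the metadata_mask contains no ones.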
