diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index f3f942e7..77024e10 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -48,7 +48,7 @@ class BasicMetadata:
 
 
 def add_metadata_and_chunk_examples(
-    examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: MetadataConfig
+    examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: MetadataConfig, without_metadata: bool = False
 ) -> Dict[str, List]:
     """Adds metadata to the provided input examples, encodes them and groups them in chunks of size
     `cfg.max_seq_len`.
@@ -124,16 +124,26 @@ def is_metadata(idx: int) -> bool:
     for text_chunk_encoded, chunk_metadata_mask in chunks(
         max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
     ):
-        total_len = prefix_len + len(text_chunk_encoded)
-        padding_len = max_text_len - len(text_chunk_encoded)
-
-        input_ids = metadata_prefix_encoded + text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
-        attention_mask = [1] * total_len + [0] * padding_len
-        metadata_mask = [1] * prefix_len + [int(x) for x in chunk_metadata_mask] + [0] * padding_len
-
-        linearized_examples["input_ids"].append(input_ids)
-        linearized_examples["attention_mask"].append(attention_mask)
-        linearized_examples["metadata_mask"].append(metadata_mask)
+        if without_metadata:
+            total_len = len(text_chunk_encoded)
+            padding_len = cfg.max_seq_len - len(text_chunk_encoded)
+            attention_mask = [1] * total_len + [0] * padding_len
+            input_ids = text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
+            metadata_mask = [0] * total_len + [0] * padding_len
+            linearized_examples["input_ids"].append(input_ids)
+            linearized_examples["attention_mask"].append(attention_mask)
+            linearized_examples["metadata_mask"].append(metadata_mask)
+        else:
+            total_len = prefix_len + len(text_chunk_encoded)
+            padding_len = max_text_len - len(text_chunk_encoded)
+
+            input_ids = metadata_prefix_encoded + text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
+            attention_mask = [1] * total_len + [0] * padding_len
+            metadata_mask = [1] * prefix_len + [int(x) for x in chunk_metadata_mask] + [0] * padding_len
+
+            linearized_examples["input_ids"].append(input_ids)
+            linearized_examples["attention_mask"].append(attention_mask)
+            linearized_examples["metadata_mask"].append(metadata_mask)
 
     return linearized_examples
 
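
For review context, here is a minimal standalone sketch of what the new `without_metadata` branch produces for a single chunk. All values are hypothetical: `max_seq_len` stands in for `cfg.max_seq_len`, `eos_token_id` for `tokenizer.eos_token_id`, and the token ids are made up.

```python
# Sketch of the without_metadata padding logic from the diff above.
# Hypothetical stand-ins for cfg.max_seq_len and tokenizer.eos_token_id.
max_seq_len = 8
eos_token_id = 0
text_chunk_encoded = [17, 42, 9]  # an encoded text chunk, no metadata prefix

total_len = len(text_chunk_encoded)                  # 3
padding_len = max_seq_len - len(text_chunk_encoded)  # 5

attention_mask = [1] * total_len + [0] * padding_len
input_ids = text_chunk_encoded + [eos_token_id] * padding_len
metadata_mask = [0] * total_len + [0] * padding_len  # all zeros: no token is metadata

print(input_ids)       # [17, 42, 9, 0, 0, 0, 0, 0]
print(attention_mask)  # [1, 1, 1, 0, 0, 0, 0, 0]
print(metadata_mask)   # [0, 0, 0, 0, 0, 0, 0, 0]
```

Note that all three lists come out at length `cfg.max_seq_len`, matching what the `else` branch produces (where `prefix_len + max_text_len == cfg.max_seq_len`), so examples built with and without metadata can be collated into the same batches.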