From c419d71aa98ff4925073b9ad6a8179e0bd89337f Mon Sep 17 00:00:00 2001 From: Ashim Mahara Date: Sun, 15 Feb 2026 15:31:59 -0500 Subject: [PATCH] included finished metadata to better track whether a trajectory is finished, this helps to set the continue_final_message which in turn fixes the apply_chat_template_error since the continue_final_message is always true in the previous tokenize_trajectory call path --- src/art/preprocessing/tokenize.py | 12 +++++++++--- src/art/trajectories.py | 1 + 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 7d30d590..2a1c6645 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -158,6 +158,12 @@ def tokenize_trajectory( Tokenizes a trajectory and returns a TokenizedResult. """ # Find the index of the last assistant message + + # Check if the trajectory is finished + continue_final_message = ( + False if trajectory.metadata.get("finished", False) is True else True + ) + last_assistant_index = -1 for i, message in enumerate(history.messages_and_choices): if ( @@ -185,7 +191,7 @@ def tokenize_trajectory( tokenizer.apply_chat_template( cast(list[dict], messages), tools=tools, - continue_final_message=True, + continue_final_message=continue_final_message, tokenize=False, ), ) @@ -194,7 +200,7 @@ def tokenize_trajectory( tokenizer.apply_chat_template( cast(list[dict], messages), tools=tools, - continue_final_message=True, + continue_final_message=continue_final_message, ), ) sentinal_token_id = max( @@ -229,7 +235,7 @@ def tokenize_trajectory( tokenizer.apply_chat_template( cast(list[dict], token_template_messages), tools=tools, - continue_final_message=True, + continue_final_message=continue_final_message, ), ) assistant_mask: list[int] = [0] * len(token_ids) diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 5a907950..481e8e48 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -59,6 +59,7 @@ def log(self, message: str) -> None: def finish(self) -> "Trajectory": duration = (datetime.now() - self.start_time).total_seconds() self.metrics["duration"] = duration + self.metadata["finished"] = True return self @asynccontextmanager