From 403b9a56867d384b00b2b3d20fcb954a898e01fa Mon Sep 17 00:00:00 2001
From: Richard Lundeen
Date: Fri, 16 Jan 2026 23:36:50 -0800
Subject: [PATCH 1/3] Adding better exception messages

---
 pyrit/exceptions/__init__.py                   |  18 +-
 pyrit/exceptions/exception_context.py          | 216 ++++++++++++++
 pyrit/exceptions/exceptions_helpers.py         |  61 +++-
 .../attack/multi_turn/chunked_request.py       |  32 +-
 pyrit/executor/attack/multi_turn/crescendo.py  |  80 +++--
 .../attack/multi_turn/multi_prompt_sending.py  |  46 ++-
 .../executor/attack/multi_turn/red_teaming.py  |  65 +++--
 .../attack/multi_turn/tree_of_attacks.py       |  95 ++++--
 .../attack/single_turn/prompt_sending.py       |  46 ++-
 pyrit/executor/core/strategy.py                |  31 +-
 pyrit/prompt_normalizer/prompt_normalizer.py   |  43 ++-
 .../unit/exceptions/test_exception_context.py  | 274 ++++++++++++++++++
 .../exceptions/test_exceptions_helpers.py      | 134 +++++++++
 .../attack/multi_turn/test_chunked_request.py  |   1 +
 .../attack/multi_turn/test_tree_of_attacks.py  |   1 +
 tests/unit/executor/core/test_strategy.py      | 207 +++++++++++++
 .../test_prompt_normalizer.py                  | 124 ++++++++
 17 files changed, 1328 insertions(+), 146 deletions(-)
 create mode 100644 pyrit/exceptions/exception_context.py
 create mode 100644 tests/unit/exceptions/test_exception_context.py
 create mode 100644 tests/unit/executor/core/test_strategy.py

diff --git a/pyrit/exceptions/__init__.py b/pyrit/exceptions/__init__.py
index 9e34d4762..527de6d0c 100644
--- a/pyrit/exceptions/__init__.py
+++ b/pyrit/exceptions/__init__.py
@@ -15,11 +15,26 @@
     pyrit_placeholder_retry,
     pyrit_target_retry,
 )
+from pyrit.exceptions.exception_context import (
+    ComponentRole,
+    ExecutionContext,
+    ExecutionContextManager,
+    clear_execution_context,
+    get_execution_context,
+    set_execution_context,
+    with_execution_context,
+)
 from pyrit.exceptions.exceptions_helpers import remove_markdown_json

 __all__ = [
     "BadRequestException",
+    "clear_execution_context",
+    "ComponentRole",
     "EmptyResponseException",
+    "ExecutionContext",
+    "ExecutionContextManager",
+    "get_execution_context",
+    "get_retry_max_num_attempts",
     "handle_bad_request_exception",
     "InvalidJsonException",
     "MissingPromptPlaceholderException",
@@ -30,5 +45,6 @@
     "pyrit_placeholder_retry",
     "RateLimitException",
     "remove_markdown_json",
-    "get_retry_max_num_attempts",
+    "set_execution_context",
+    "with_execution_context",
 ]
diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py
new file mode 100644
index 000000000..9ff68850d
--- /dev/null
+++ b/pyrit/exceptions/exception_context.py
@@ -0,0 +1,216 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Context management for enhanced exception and retry logging in PyRIT.
+
+This module provides a contextvar-based system for tracking which component
+(objective_target, adversarial_chat, objective_scorer, etc.) is currently
+executing, allowing retry decorators and exception handlers to include
+meaningful context in their messages.
+"""
+
+from contextvars import ContextVar
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, Optional
+
+
+class ComponentRole(Enum):
+    """
+    Identifies the role of a component within an attack execution.
+
+    This enum is used to provide meaningful context in error messages and retry logs,
+    helping users identify which part of an attack encountered an issue.
+    """
+
+    # Core attack components
+    OBJECTIVE_TARGET = "objective_target"
+    ADVERSARIAL_CHAT = "adversarial_chat"
+
+    # Scoring components
+    OBJECTIVE_SCORER = "objective_scorer"
+    OBJECTIVE_SCORER_TARGET = "objective_scorer_target"
+    REFUSAL_SCORER = "refusal_scorer"
+    REFUSAL_SCORER_TARGET = "refusal_scorer_target"
+    AUXILIARY_SCORER = "auxiliary_scorer"
+    AUXILIARY_SCORER_TARGET = "auxiliary_scorer_target"
+
+    # Conversion components
+    CONVERTER = "converter"
+    CONVERTER_TARGET = "converter_target"
+
+    # Other components
+    UNKNOWN = "unknown"
+
+
+@dataclass
+class ExecutionContext:
+    """
+    Holds context information about the currently executing component.
+
+    This context is used to enrich error messages and retry logs with
+    information about which component failed and its configuration.
+    """
+
+    # The role of the component (e.g., objective_scorer, adversarial_chat)
+    component_role: ComponentRole = ComponentRole.UNKNOWN
+
+    # The attack strategy class name (e.g., "PromptSendingAttack")
+    attack_strategy_name: Optional[str] = None
+
+    # The identifier from the attack strategy's get_identifier()
+    attack_identifier: Optional[Dict[str, Any]] = None
+
+    # The identifier from the component's get_identifier() (target, scorer, etc.)
+    component_identifier: Optional[Dict[str, Any]] = None
+
+    # The objective target conversation ID if available
+    objective_target_conversation_id: Optional[str] = None
+
+    # The endpoint/URI if available (extracted from component_identifier for quick access)
+    endpoint: Optional[str] = None
+
+    # The component class name (extracted from component_identifier.__type__ for quick access)
+    component_name: Optional[str] = None
+
+    def get_retry_context_string(self) -> str:
+        """
+        Generate a concise context string for retry log messages.
+
+        Returns:
+            str: A formatted string with component role, component name, and endpoint.
+        """
+        parts = [self.component_role.value]
+        if self.component_name:
+            parts.append(f"({self.component_name})")
+        if self.endpoint:
+            parts.append(f"endpoint: {self.endpoint}")
+        return " ".join(parts)
+
+    def get_exception_details(self) -> str:
+        """
+        Generate detailed exception context for error messages.
+
+        Returns:
+            str: A multi-line formatted string with full context details.
+        """
+        lines = []
+
+        if self.attack_strategy_name:
+            lines.append(f"Attack: {self.attack_strategy_name}")
+
+        lines.append(f"Component: {self.component_role.value}")
+
+        if self.objective_target_conversation_id:
+            lines.append(f"Objective target conversation ID: {self.objective_target_conversation_id}")
+
+        if self.attack_identifier:
+            lines.append(f"Attack identifier: {self.attack_identifier}")
+
+        if self.component_identifier:
+            lines.append(f"{self.component_role.value} identifier: {self.component_identifier}")
+
+        return "\n".join(lines)
+
+
+# The contextvar that stores the current execution context
+_execution_context: ContextVar[Optional[ExecutionContext]] = ContextVar("execution_context", default=None)
+
+
+def get_execution_context() -> Optional[ExecutionContext]:
+    """
+    Get the current execution context.
+
+    Returns:
+        Optional[ExecutionContext]: The current context, or None if not set.
+    """
+    return _execution_context.get()
+
+
+def set_execution_context(context: ExecutionContext) -> None:
+    """
+    Set the current execution context.
+
+    Args:
+        context: The execution context to set.
+    """
+    _execution_context.set(context)
+
+
+def clear_execution_context() -> None:
+    """Clear the current execution context."""
+    _execution_context.set(None)
+
+
+@dataclass
+class ExecutionContextManager:
+    """
+    A context manager for setting execution context during component operations.
+
+    This class provides a convenient way to set and automatically clear
+    execution context when entering and exiting a code block.
+
+    On successful exit, the context is restored to its previous value.
+    On exception, the context is preserved so exception handlers can access it.
+    """
+
+    context: ExecutionContext
+    _token: Any = field(default=None, init=False, repr=False)
+
+    def __enter__(self) -> "ExecutionContextManager":
+        """Set the execution context on entry."""
+        self._token = _execution_context.set(self.context)
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        """
+        Restore the previous context on exit.
+
+        If an exception occurred, the context is preserved so that exception
+        handlers higher in the call stack can access it for enhanced error messages.
+        """
+        if exc_type is None:
+            # No exception - restore previous context
+            _execution_context.reset(self._token)
+        # On exception, leave context in place for exception handlers to read
+
+
+def with_execution_context(
+    *,
+    component_role: ComponentRole,
+    attack_strategy_name: Optional[str] = None,
+    attack_identifier: Optional[Dict[str, Any]] = None,
+    component_identifier: Optional[Dict[str, Any]] = None,
+    objective_target_conversation_id: Optional[str] = None,
+) -> ExecutionContextManager:
+    """
+    Create an execution context manager with the specified parameters.
+
+    Args:
+        component_role: The role of the component being executed.
+        attack_strategy_name: The name of the attack strategy class.
+        attack_identifier: The identifier from attack.get_identifier().
+        component_identifier: The identifier from component.get_identifier().
+        objective_target_conversation_id: The objective target conversation ID if available.
+
+    Returns:
+        ExecutionContextManager: A context manager that sets/clears the context.
+    """
+    # Extract endpoint and component_name from component_identifier if available
+    endpoint = None
+    component_name = None
+    if component_identifier:
+        endpoint = component_identifier.get("endpoint")
+        component_name = component_identifier.get("__type__")
+
+    context = ExecutionContext(
+        component_role=component_role,
+        attack_strategy_name=attack_strategy_name,
+        attack_identifier=attack_identifier,
+        component_identifier=component_identifier,
+        objective_target_conversation_id=objective_target_conversation_id,
+        endpoint=endpoint,
+        component_name=component_name,
+    )
+    return ExecutionContextManager(context=context)
diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py
index e986421e9..00ac0d317 100644
--- a/pyrit/exceptions/exceptions_helpers.py
+++ b/pyrit/exceptions/exceptions_helpers.py
@@ -12,16 +12,57 @@
 def log_exception(retry_state: RetryCallState) -> None:
-    # Log each retry attempt with exception details at ERROR level
-    elapsed_time = time.monotonic() - retry_state.start_time
-    call_count = retry_state.attempt_number
-
-    if retry_state.outcome.failed:
-        exception = retry_state.outcome.exception()
-        logger.error(
-            f"Retry attempt {call_count} for {retry_state.fn.__name__} failed with exception: {exception}. "
-            f"Elapsed time: {elapsed_time} seconds. 
Total calls: {call_count}" - ) + """ + Log each retry attempt with exception details at ERROR level. + + If an execution context is set (via exception_context module), includes + component role and endpoint information for easier debugging. + + Args: + retry_state: The tenacity retry state containing attempt information. + """ + # Validate retry_state has required attributes before proceeding + if not retry_state: + logger.error("Retry callback invoked with no retry state") + return + + # Safely extract values with defaults + call_count = getattr(retry_state, "attempt_number", None) or 0 + start_time = getattr(retry_state, "start_time", None) + elapsed_time = (time.monotonic() - start_time) if start_time is not None else 0.0 + + outcome = getattr(retry_state, "outcome", None) + if not outcome or not getattr(outcome, "failed", False): + return + + exception = outcome.exception() if hasattr(outcome, "exception") else None + + # Get function name safely + fn = getattr(retry_state, "fn", None) + fn_name = getattr(fn, "__name__", "unknown") if fn else "unknown" + + # Build the "for X" part of the message based on execution context + for_clause = fn_name + try: + from pyrit.exceptions.exception_context import get_execution_context + + exec_context = get_execution_context() + if exec_context: + # Format: "objective scorer; TrueFalseScorer::_score_value_with_llm" + role_display = exec_context.component_role.value.replace("_", " ") + if exec_context.component_name: + for_clause = f"{role_display}. {exec_context.component_name}::{fn_name}" + else: + for_clause = f"{role_display}. {fn_name}" + except Exception: + # Don't let context retrieval errors break retry logging + pass + + logger.error( + f"Retry attempt {call_count} for {for_clause} " + f"failed with exception: {exception}. " + f"Elapsed time: {elapsed_time} seconds. 
Total calls: {call_count}" + ) def remove_start_md_json(response_msg: str) -> str: diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 92b399db5..91a24ed6e 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -8,6 +8,7 @@ from typing import Any, List, Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults +from pyrit.exceptions import ComponentRole, with_execution_context from pyrit.executor.attack.component import ConversationManager from pyrit.executor.attack.core.attack_config import ( AttackConverterConfig, @@ -260,15 +261,22 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac message = Message.from_prompt(prompt=chunk_prompt, role="user") # Send the prompt using the normalizer - response = await self._prompt_normalizer.send_prompt_async( - message=message, - target=self._objective_target, - conversation_id=context.session.conversation_id, - request_converter_configurations=self._request_converters, - response_converter_configurations=self._response_converters, - labels=context.memory_labels, + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), - ) + component_identifier=self._objective_target.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + response = await self._prompt_normalizer.send_prompt_async( + message=message, + target=self._objective_target, + conversation_id=context.session.conversation_id, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + labels=context.memory_labels, + attack_identifier=self.get_identifier(), + ) # Store the response if response: @@ -349,7 +357,13 @@ async def _score_combined_value_async( if not self._objective_scorer: return None - scores = await self._objective_scorer.score_text_async(text=combined_value, objective=objective) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name=self.__class__.__name__, + attack_identifier=self.get_identifier(), + component_identifier=self._objective_scorer.get_identifier(), + ): + scores = await self._objective_scorer.score_text_async(text=combined_value, objective=objective) return scores[0] if scores else None async def _teardown_async(self, *, context: ChunkedRequestAttackContext) -> None: diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index ac937fb54..a7e618161 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -10,9 +10,11 @@ from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH from pyrit.exceptions import ( + ComponentRole, InvalidJsonException, pyrit_json_retry, remove_markdown_json, + with_execution_context, ) from pyrit.executor.attack.component import ( ConversationManager, @@ -506,13 +508,20 @@ async def _send_prompt_to_adversarial_chat_async( prompt_metadata=prompt_metadata, ) - response = await self._prompt_normalizer.send_prompt_async( - message=message, - conversation_id=context.session.adversarial_chat_conversation_id, - target=self._adversarial_chat, + with with_execution_context( + component_role=ComponentRole.ADVERSARIAL_CHAT, + 
attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), - labels=context.memory_labels, - ) + component_identifier=self._adversarial_chat.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + response = await self._prompt_normalizer.send_prompt_async( + message=message, + conversation_id=context.session.adversarial_chat_conversation_id, + target=self._adversarial_chat, + attack_identifier=self.get_identifier(), + labels=context.memory_labels, + ) if not response: raise ValueError("No response received from adversarial chat") @@ -582,15 +591,22 @@ async def _send_prompt_to_objective_target_async( prompt_preview = attack_message.get_value()[:100] if attack_message.get_value() else "" self._logger.debug(f"Sending prompt to {objective_target_type}: {prompt_preview}...") - response = await self._prompt_normalizer.send_prompt_async( - message=attack_message, - target=self._objective_target, - conversation_id=context.session.conversation_id, - request_converter_configurations=self._request_converters, - response_converter_configurations=self._response_converters, + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), - labels=context.memory_labels, - ) + component_identifier=self._objective_target.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + response = await self._prompt_normalizer.send_prompt_async( + message=attack_message, + target=self._objective_target, + conversation_id=context.session.conversation_id, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + attack_identifier=self.get_identifier(), + labels=context.memory_labels, + ) if not response: raise ValueError("No response received from objective target") @@ -614,9 +630,16 @@ async def _check_refusal_async(self, context: CrescendoAttackContext, objective: if not context.last_response: raise ValueError("No response available in context to check for refusal") - scores = await self._refusal_scorer.score_async( - message=context.last_response, objective=objective, skip_on_error_result=False - ) + with with_execution_context( + component_role=ComponentRole.REFUSAL_SCORER, + attack_strategy_name=self.__class__.__name__, + attack_identifier=self.get_identifier(), + component_identifier=self._refusal_scorer.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + scores = await self._refusal_scorer.score_async( + message=context.last_response, objective=objective, skip_on_error_result=False + ) return scores[0] async def _score_response_async(self, *, context: CrescendoAttackContext) -> Score: @@ -636,14 +659,21 @@ async def _score_response_async(self, *, context: CrescendoAttackContext) -> Sco if not context.last_response: raise ValueError("No response available in context to score") - scoring_results = await Scorer.score_response_async( - response=context.last_response, - objective_scorer=self._objective_scorer, - auxiliary_scorers=self._auxiliary_scorers, - role_filter="assistant", - objective=context.objective, - skip_on_error_result=False, - ) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name=self.__class__.__name__, + attack_identifier=self.get_identifier(), + component_identifier=self._objective_scorer.get_identifier(), + 
objective_target_conversation_id=context.session.conversation_id, + ): + scoring_results = await Scorer.score_response_async( + response=context.last_response, + objective_scorer=self._objective_scorer, + auxiliary_scorers=self._auxiliary_scorers, + role_filter="assistant", + objective=context.objective, + skip_on_error_result=False, + ) objective_score = scoring_results["objective_scores"] if not objective_score: diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index f89bc4482..c3bf40db4 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -7,6 +7,7 @@ from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.utils import get_kwarg_param +from pyrit.exceptions import ComponentRole, with_execution_context from pyrit.executor.attack.component import ConversationManager from pyrit.executor.attack.core.attack_config import ( AttackConverterConfig, @@ -334,15 +335,22 @@ async def _send_prompt_to_objective_target_async( Optional[Message]: The model's response if successful, or None if the request was filtered, blocked, or encountered an error. """ - return await self._prompt_normalizer.send_prompt_async( - message=current_message, - target=self._objective_target, - conversation_id=context.session.conversation_id, - request_converter_configurations=self._request_converters, - response_converter_configurations=self._response_converters, - labels=context.memory_labels, # combined with strategy labels at _setup() + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), - ) + component_identifier=self._objective_target.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + return await self._prompt_normalizer.send_prompt_async( + message=current_message, + target=self._objective_target, + conversation_id=context.session.conversation_id, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + labels=context.memory_labels, # combined with strategy labels at _setup() + attack_identifier=self.get_identifier(), + ) async def _evaluate_response_async(self, *, response: Message, objective: str) -> Optional[Score]: """ @@ -360,14 +368,20 @@ async def _evaluate_response_async(self, *, response: Message, objective: str) - no objective scorer is set. Note that auxiliary scorer results are not returned but are still executed and stored. 
""" - scoring_results = await Scorer.score_response_async( - response=response, - auxiliary_scorers=self._auxiliary_scorers, - objective_scorer=self._objective_scorer if self._objective_scorer else None, - role_filter="assistant", - objective=objective, - skip_on_error_result=True, - ) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name=self.__class__.__name__, + attack_identifier=self.get_identifier(), + component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, + ): + scoring_results = await Scorer.score_response_async( + response=response, + auxiliary_scorers=self._auxiliary_scorers, + objective_scorer=self._objective_scorer if self._objective_scorer else None, + role_filter="assistant", + objective=objective, + skip_on_error_result=True, + ) objective_scores = scoring_results["objective_scores"] if not objective_scores: diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 886361254..a121e714a 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -11,6 +11,7 @@ from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_RED_TEAM_PATH from pyrit.common.utils import warn_if_set +from pyrit.exceptions import ComponentRole, with_execution_context from pyrit.executor.attack.component import ( ConversationManager, get_adversarial_chat_messages, @@ -361,13 +362,20 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] logger.debug(f"Sending prompt to adversarial chat: {prompt_text[:50]}...") prompt_message = Message.from_prompt(prompt=prompt_text, role="user") - response = await self._prompt_normalizer.send_prompt_async( - message=prompt_message, - conversation_id=context.session.adversarial_chat_conversation_id, - target=self._adversarial_chat, + with with_execution_context( + component_role=ComponentRole.ADVERSARIAL_CHAT, + attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), - labels=context.memory_labels, - ) + component_identifier=self._adversarial_chat.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + response = await self._prompt_normalizer.send_prompt_async( + message=prompt_message, + conversation_id=context.session.adversarial_chat_conversation_id, + target=self._adversarial_chat, + attack_identifier=self.get_identifier(), + labels=context.memory_labels, + ) # Check if the response is valid if response is None: @@ -512,16 +520,23 @@ async def _send_prompt_to_objective_target_async( """ logger.info(f"Sending prompt to target: {message.get_value()[:50]}...") - # Send the message to the target - response = await self._prompt_normalizer.send_prompt_async( - message=message, - conversation_id=context.session.conversation_id, - request_converter_configurations=self._request_converters, - response_converter_configurations=self._response_converters, - target=self._objective_target, - labels=context.memory_labels, + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), - ) + component_identifier=self._objective_target.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + # Send the message to the target + response = await self._prompt_normalizer.send_prompt_async( + 
message=message, + conversation_id=context.session.conversation_id, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + target=self._objective_target, + labels=context.memory_labels, + attack_identifier=self.get_identifier(), + ) if response is None: # Easiest way to handle this is to raise an error @@ -552,12 +567,20 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) - logger.warning("No response available in context to score") return None - # score_async handles blocked, filtered, other errors - scoring_results = await self._objective_scorer.score_async( - message=context.last_response, - role_filter="assistant", - objective=context.objective, - ) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name=self.__class__.__name__, + attack_identifier=self.get_identifier(), + component_identifier=self._objective_scorer.get_identifier(), + objective_target_conversation_id=context.session.conversation_id, + ): + # score_async handles blocked, filtered, other errors + scoring_results = await self._objective_scorer.score_async( + message=context.last_response, + role_filter="assistant", + objective=context.objective, + ) + objective_scores = scoring_results return objective_scores[0] if objective_scores else None diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 9885d41f4..74c9fa86a 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -15,10 +15,12 @@ from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH from pyrit.common.utils import combine_dict from pyrit.exceptions import ( + ComponentRole, InvalidJsonException, get_retry_max_num_attempts, pyrit_json_retry, remove_markdown_json, + with_execution_context, ) from pyrit.executor.attack.component import ( ConversationManager, @@ -266,6 +268,7 @@ def __init__( response_converters: List[PromptConverterConfiguration], auxiliary_scorers: Optional[List[Scorer]], attack_id: dict[str, str], + attack_strategy_name: str, memory_labels: Optional[dict[str, str]] = None, parent_id: Optional[str] = None, prompt_normalizer: Optional[PromptNormalizer] = None, @@ -287,6 +290,7 @@ def __init__( response_converters (List[PromptConverterConfiguration]): Converters for response normalization auxiliary_scorers (Optional[List[Scorer]]): Additional scorers for the response attack_id (dict[str, str]): Unique identifier for the attack. + attack_strategy_name (str): Name of the attack strategy for execution context. memory_labels (Optional[dict[str, str]]): Labels for memory storage. parent_id (Optional[str]): ID of the parent node, if this is a child node prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts and responses. 
@@ -306,6 +310,7 @@ def __init__( self._response_converters = response_converters self._auxiliary_scorers = auxiliary_scorers or [] self._attack_id = attack_id + self._attack_strategy_name = attack_strategy_name self._memory_labels = memory_labels or {} # Initialize utilities @@ -514,15 +519,22 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: message = Message.from_prompt(prompt=prompt, role="user") # Send prompt with configured converters - response = await self._prompt_normalizer.send_prompt_async( - message=message, - request_converter_configurations=self._request_converters, - response_converter_configurations=self._response_converters, - conversation_id=self.objective_target_conversation_id, - target=self._objective_target, - labels=self._memory_labels, + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name=self._attack_strategy_name, attack_identifier=self._attack_id, - ) + component_identifier=self._objective_target.get_identifier(), + objective_target_conversation_id=self.objective_target_conversation_id, + ): + response = await self._prompt_normalizer.send_prompt_async( + message=message, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + conversation_id=self.objective_target_conversation_id, + target=self._objective_target, + labels=self._memory_labels, + attack_identifier=self._attack_id, + ) # Store the last response text for reference response_piece = response.get_piece() @@ -562,15 +574,22 @@ async def _send_initial_prompt_to_target_async(self) -> Message: logger.debug(f"Node {self.node_id}: Using initial prompt, bypassing adversarial chat") # Send prompt with configured converters - response = await self._prompt_normalizer.send_prompt_async( - message=message, - request_converter_configurations=self._request_converters, - response_converter_configurations=self._response_converters, - conversation_id=self.objective_target_conversation_id, - target=self._objective_target, - labels=self._memory_labels, + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name=self._attack_strategy_name, attack_identifier=self._attack_id, - ) + component_identifier=self._objective_target.get_identifier(), + objective_target_conversation_id=self.objective_target_conversation_id, + ): + response = await self._prompt_normalizer.send_prompt_async( + message=message, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + conversation_id=self.objective_target_conversation_id, + target=self._objective_target, + labels=self._memory_labels, + attack_identifier=self._attack_id, + ) # Store the last response text for reference response_piece = response.get_piece() @@ -608,14 +627,21 @@ async def _score_response_async(self, *, response: Message, objective: str) -> N the TAP algorithm explores in subsequent iterations. 
""" # Use the Scorer utility method to handle all scoring - scoring_results = await Scorer.score_response_async( - response=response, - objective_scorer=self._objective_scorer, - auxiliary_scorers=self._auxiliary_scorers, - role_filter="assistant", - objective=objective, - skip_on_error_result=False, - ) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name=self._attack_strategy_name, + attack_identifier=self._attack_id, + component_identifier=self._objective_scorer.get_identifier(), + objective_target_conversation_id=self.objective_target_conversation_id, + ): + scoring_results = await Scorer.score_response_async( + response=response, + objective_scorer=self._objective_scorer, + auxiliary_scorers=self._auxiliary_scorers, + role_filter="assistant", + objective=objective, + skip_on_error_result=False, + ) # Extract objective score objective_scores = scoring_results["objective_scores"] @@ -733,6 +759,7 @@ def duplicate(self) -> "_TreeOfAttacksNode": response_converters=self._response_converters, auxiliary_scorers=self._auxiliary_scorers, attack_id=self._attack_id, + attack_strategy_name=self._attack_strategy_name, memory_labels=self._memory_labels, desired_response_prefix=self._desired_response_prefix, parent_id=self.node_id, @@ -1045,13 +1072,20 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: message.message_pieces[0].prompt_metadata = {"response_format": "json"} # Send and get response - response = await self._prompt_normalizer.send_prompt_async( - message=message, - conversation_id=self.adversarial_chat_conversation_id, - target=self._adversarial_chat, - labels=self._memory_labels, + with with_execution_context( + component_role=ComponentRole.ADVERSARIAL_CHAT, + attack_strategy_name=self._attack_strategy_name, attack_identifier=self._attack_id, - ) + component_identifier=self._adversarial_chat.get_identifier(), + objective_target_conversation_id=self.objective_target_conversation_id, + ): + response = await self._prompt_normalizer.send_prompt_async( + message=message, + conversation_id=self.adversarial_chat_conversation_id, + target=self._adversarial_chat, + labels=self._memory_labels, + attack_identifier=self._attack_id, + ) return response.get_value() @@ -1805,6 +1839,7 @@ def _create_attack_node( response_converters=self._response_converters, auxiliary_scorers=self._auxiliary_scorers, attack_id=self.get_identifier(), + attack_strategy_name=self.__class__.__name__, memory_labels=context.memory_labels, desired_response_prefix=self._desired_response_prefix, parent_id=parent_id, diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 0363dc95a..9aa6f2892 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -7,6 +7,7 @@ from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.utils import warn_if_set +from pyrit.exceptions import ComponentRole, with_execution_context from pyrit.executor.attack.component import ConversationManager, PrependedConversationConfig from pyrit.executor.attack.core.attack_config import AttackConverterConfig, AttackScoringConfig from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -311,15 +312,22 @@ async def _send_prompt_to_objective_target_async( Optional[Message]: The model's response if successful, or None if the request was filtered, blocked, or encountered an error. 
""" - return await self._prompt_normalizer.send_prompt_async( - message=message, - target=self._objective_target, - conversation_id=context.conversation_id, - request_converter_configurations=self._request_converters, - response_converter_configurations=self._response_converters, - labels=context.memory_labels, # combined with strategy labels at _setup() + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), - ) + component_identifier=self._objective_target.get_identifier(), + objective_target_conversation_id=context.conversation_id, + ): + return await self._prompt_normalizer.send_prompt_async( + message=message, + target=self._objective_target, + conversation_id=context.conversation_id, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + labels=context.memory_labels, # combined with strategy labels at _setup() + attack_identifier=self.get_identifier(), + ) async def _evaluate_response_async( self, @@ -342,14 +350,20 @@ async def _evaluate_response_async( no objective scorer is set. Note that auxiliary scorer results are not returned but are still executed and stored. """ - scoring_results = await Scorer.score_response_async( - response=response, - objective_scorer=self._objective_scorer, - auxiliary_scorers=self._auxiliary_scorers, - role_filter="assistant", - objective=objective, - skip_on_error_result=True, - ) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name=self.__class__.__name__, + attack_identifier=self.get_identifier(), + component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, + ): + scoring_results = await Scorer.score_response_async( + response=response, + objective_scorer=self._objective_scorer, + auxiliary_scorers=self._auxiliary_scorers, + role_filter="assistant", + objective=objective, + skip_on_error_result=True, + ) if not self._objective_scorer: return None diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index d7c59e44e..d7b9e1d22 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -16,6 +16,7 @@ from pyrit.common import default_values from pyrit.common.logger import logger +from pyrit.exceptions import clear_execution_context, get_execution_context from pyrit.models import StrategyResultT StrategyContextT = TypeVar("StrategyContextT", bound="StrategyContext") @@ -348,8 +349,34 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra except Exception as e: # Notify error event await self._handle_event(event=StrategyEvent.ON_ERROR, context=context, error=e) - # Raise a specific execution error - raise RuntimeError(f"Strategy execution failed for {self.__class__.__name__}: {str(e)}") from e + + # Build enhanced error message with execution context if available + # Note: The context is preserved on exception by ExecutionContextManager + exec_context = get_execution_context() + if exec_context: + error_details = exec_context.get_exception_details() + + # Extract the root cause exception for better diagnostics + root_cause = e + while root_cause.__cause__ is not None: + root_cause = root_cause.__cause__ + + # Include root cause type and message if different from the immediate exception + if root_cause is not e: + root_cause_info = f"\n\nRoot cause: {type(root_cause).__name__}: {str(root_cause)}" + else: + 
root_cause_info = "" + + error_message = ( + f"Strategy execution failed for {exec_context.component_role.value} " + f"in {self.__class__.__name__}: {str(e)}{root_cause_info}\n\nDetails:\n{error_details}" + ) + # Clear the context now that we've read it + clear_execution_context() + else: + error_message = f"Strategy execution failed for {self.__class__.__name__}: {str(e)}" + + raise RuntimeError(error_message) from e async def execute_async(self, **kwargs: Any) -> StrategyResultT: """ diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 753b9220c..55a839a6a 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -8,7 +8,12 @@ from typing import Any, List, Optional from uuid import uuid4 -from pyrit.exceptions import EmptyResponseException, PyritException +from pyrit.exceptions import ( + ComponentRole, + EmptyResponseException, + get_execution_context, + with_execution_context, +) from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import ( Message, @@ -211,8 +216,7 @@ async def convert_values( message (Message): The message containing pieces to be converted. Raises: - PyritException: If a converter raises a PyRIT exception (re-raised with enhanced context). - RuntimeError: If a converter raises a non-PyRIT exception (wrapped with converter context). + Exception: Any exception from converters propagates with execution context for error tracing. """ for converter_configuration in converter_configurations: for piece_index, piece in enumerate(message.message_pieces): @@ -232,23 +236,30 @@ async def convert_values( converted_text_data_type = piece.converted_value_data_type for converter in converter_configuration.converters: + # Inherit attack context from outer execution context (set by attack strategy) + outer_context = get_execution_context() + try: - converter_result = await converter.convert_tokens_async( - prompt=converted_text, - input_type=converted_text_data_type, - start_token=self._start_token, - end_token=self._end_token, - ) + with with_execution_context( + component_role=ComponentRole.CONVERTER, + attack_strategy_name=outer_context.attack_strategy_name if outer_context else None, + attack_identifier=outer_context.attack_identifier if outer_context else None, + component_identifier=converter.get_identifier(), + objective_target_conversation_id=( + outer_context.objective_target_conversation_id if outer_context else None + ), + ): + converter_result = await converter.convert_tokens_async( + prompt=converted_text, + input_type=converted_text_data_type, + start_token=self._start_token, + end_token=self._end_token, + ) converted_text = converter_result.output_text converted_text_data_type = converter_result.output_type - except PyritException as e: - # Re-raise PyRIT exceptions with enhanced context while preserving type for retry decorators - e.message = f"Error in converter {converter.__class__.__name__}: {e.message}" - e.args = (f"Status Code: {e.status_code}, Message: {e.message}",) + except Exception: + # Let the exception propagate - execution context will add converter details raise - except Exception as e: - # Wrap non-PyRIT exceptions for better error tracing - raise RuntimeError(f"Error in converter {converter.__class__.__name__}: {str(e)}") from e piece.converted_value = converted_text piece.converted_value_data_type = converted_text_data_type diff --git a/tests/unit/exceptions/test_exception_context.py 
b/tests/unit/exceptions/test_exception_context.py new file mode 100644 index 000000000..99c4b4c7b --- /dev/null +++ b/tests/unit/exceptions/test_exception_context.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.exceptions import ( + ComponentRole, + ExecutionContext, + ExecutionContextManager, + clear_execution_context, + get_execution_context, + set_execution_context, + with_execution_context, +) + + +class TestExecutionContext: + """Tests for the ExecutionContext dataclass.""" + + def test_default_values(self): + """Test that ExecutionContext has correct default values.""" + context = ExecutionContext() + assert context.component_role == ComponentRole.UNKNOWN + assert context.attack_strategy_name is None + assert context.attack_identifier is None + assert context.component_identifier is None + assert context.objective_target_conversation_id is None + assert context.endpoint is None + assert context.component_name is None + + def test_initialization_with_values(self): + """Test ExecutionContext initialization with all values.""" + context = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="PromptSendingAttack", + attack_identifier={"__type__": "PromptSendingAttack", "id": "abc123"}, + component_identifier={"__type__": "OpenAIChatTarget", "endpoint": "https://api.openai.com"}, + objective_target_conversation_id="conv-123", + endpoint="https://api.openai.com", + component_name="OpenAIChatTarget", + ) + assert context.component_role == ComponentRole.OBJECTIVE_TARGET + assert context.attack_strategy_name == "PromptSendingAttack" + assert context.attack_identifier == {"__type__": "PromptSendingAttack", "id": "abc123"} + assert context.component_identifier == {"__type__": "OpenAIChatTarget", "endpoint": "https://api.openai.com"} + assert context.objective_target_conversation_id == "conv-123" + assert context.endpoint == "https://api.openai.com" + assert context.component_name == "OpenAIChatTarget" + + def test_get_retry_context_string_minimal(self): + """Test retry context string with only component role.""" + context = ExecutionContext(component_role=ComponentRole.OBJECTIVE_SCORER) + result = context.get_retry_context_string() + assert result == "objective_scorer" + + def test_get_retry_context_string_with_component_name(self): + """Test retry context string includes component name when set.""" + context = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_SCORER, + component_name="TrueFalseScorer", + ) + result = context.get_retry_context_string() + assert "objective_scorer" in result + assert "(TrueFalseScorer)" in result + + def test_get_retry_context_string_with_endpoint(self): + """Test retry context string includes endpoint when set.""" + context = ExecutionContext( + component_role=ComponentRole.ADVERSARIAL_CHAT, + endpoint="https://api.example.com", + ) + result = context.get_retry_context_string() + assert "adversarial_chat" in result + assert "endpoint: https://api.example.com" in result + + def test_get_retry_context_string_full(self): + """Test retry context string with component name and endpoint.""" + context = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_TARGET, + component_name="OpenAIChatTarget", + endpoint="https://api.openai.com", + ) + result = context.get_retry_context_string() + assert "objective_target" in result + assert "(OpenAIChatTarget)" in result + assert "endpoint: https://api.openai.com" in result + + def 
test_get_exception_details_minimal(self): + """Test exception details with minimal context.""" + context = ExecutionContext(component_role=ComponentRole.CONVERTER) + result = context.get_exception_details() + assert "Component: converter" in result + + def test_get_exception_details_full(self): + """Test exception details with full context.""" + context = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="RedTeamingAttack", + attack_identifier={"__type__": "RedTeamingAttack", "id": "xyz"}, + component_identifier={"__type__": "OpenAIChatTarget"}, + objective_target_conversation_id="conv-456", + ) + result = context.get_exception_details() + assert "Attack: RedTeamingAttack" in result + assert "Component: objective_target" in result + assert "Objective target conversation ID: conv-456" in result + assert "Attack identifier:" in result + assert "objective_target identifier:" in result + + +class TestExecutionContextFunctions: + """Tests for the context management functions.""" + + def teardown_method(self): + """Clear context after each test.""" + clear_execution_context() + + def test_get_execution_context_default(self): + """Test that get_execution_context returns None when not set.""" + clear_execution_context() + assert get_execution_context() is None + + def test_set_and_get_execution_context(self): + """Test setting and getting execution context.""" + context = ExecutionContext(component_role=ComponentRole.OBJECTIVE_SCORER) + set_execution_context(context) + retrieved = get_execution_context() + assert retrieved is context + assert retrieved.component_role == ComponentRole.OBJECTIVE_SCORER + + def test_clear_execution_context(self): + """Test clearing execution context.""" + context = ExecutionContext(component_role=ComponentRole.ADVERSARIAL_CHAT) + set_execution_context(context) + assert get_execution_context() is not None + clear_execution_context() + assert get_execution_context() is None + + def test_context_isolation(self): + """Test that setting a new context replaces the old one.""" + context1 = ExecutionContext(component_role=ComponentRole.OBJECTIVE_TARGET) + context2 = ExecutionContext(component_role=ComponentRole.ADVERSARIAL_CHAT) + + set_execution_context(context1) + assert get_execution_context().component_role == ComponentRole.OBJECTIVE_TARGET + + set_execution_context(context2) + assert get_execution_context().component_role == ComponentRole.ADVERSARIAL_CHAT + + +class TestExecutionContextManager: + """Tests for the ExecutionContextManager class.""" + + def teardown_method(self): + """Clear context after each test.""" + clear_execution_context() + + def test_context_manager_sets_and_clears_context(self): + """Test that context manager sets context on enter and clears on exit.""" + context = ExecutionContext(component_role=ComponentRole.REFUSAL_SCORER) + + assert get_execution_context() is None + + with ExecutionContextManager(context=context): + assert get_execution_context() is context + + # Context should be cleared after successful exit + assert get_execution_context() is None + + def test_context_manager_preserves_context_on_exception(self): + """Test that context is preserved when an exception occurs.""" + context = ExecutionContext(component_role=ComponentRole.OBJECTIVE_TARGET) + + with pytest.raises(ValueError): + with ExecutionContextManager(context=context): + assert get_execution_context() is context + raise ValueError("Test error") + + # Context should still be set after exception + assert get_execution_context() is 
context + + def test_context_manager_nested(self): + """Test nested context managers.""" + outer_context = ExecutionContext(component_role=ComponentRole.OBJECTIVE_TARGET) + inner_context = ExecutionContext(component_role=ComponentRole.CONVERTER) + + with ExecutionContextManager(context=outer_context): + assert get_execution_context().component_role == ComponentRole.OBJECTIVE_TARGET + + with ExecutionContextManager(context=inner_context): + assert get_execution_context().component_role == ComponentRole.CONVERTER + + # After inner exits, outer should be restored + assert get_execution_context().component_role == ComponentRole.OBJECTIVE_TARGET + + # After outer exits, should be None + assert get_execution_context() is None + + +class TestWithExecutionContext: + """Tests for the with_execution_context factory function.""" + + def teardown_method(self): + """Clear context after each test.""" + clear_execution_context() + + def test_with_execution_context_creates_manager(self): + """Test that with_execution_context creates a proper context manager.""" + manager = with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="TestAttack", + ) + assert isinstance(manager, ExecutionContextManager) + assert manager.context.component_role == ComponentRole.OBJECTIVE_TARGET + assert manager.context.attack_strategy_name == "TestAttack" + + def test_with_execution_context_extracts_endpoint(self): + """Test that endpoint is extracted from component_identifier.""" + component_id = {"__type__": "OpenAIChatTarget", "endpoint": "https://api.openai.com"} + manager = with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + component_identifier=component_id, + ) + assert manager.context.endpoint == "https://api.openai.com" + + def test_with_execution_context_extracts_component_name(self): + """Test that component_name is extracted from component_identifier.__type__.""" + component_id = {"__type__": "TrueFalseScorer", "endpoint": "https://api.openai.com"} + manager = with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + component_identifier=component_id, + ) + assert manager.context.component_name == "TrueFalseScorer" + + def test_with_execution_context_no_endpoint(self): + """Test that endpoint is None when not in component_identifier.""" + component_id = {"__type__": "TextTarget"} + manager = with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + component_identifier=component_id, + ) + assert manager.context.endpoint is None + + def test_with_execution_context_full_usage(self): + """Test full usage of with_execution_context as context manager.""" + with with_execution_context( + component_role=ComponentRole.ADVERSARIAL_CHAT, + attack_strategy_name="CrescendoAttack", + attack_identifier={"id": "test"}, + component_identifier={"endpoint": "https://example.com"}, + objective_target_conversation_id="conv-789", + ): + ctx = get_execution_context() + assert ctx is not None + assert ctx.component_role == ComponentRole.ADVERSARIAL_CHAT + assert ctx.attack_strategy_name == "CrescendoAttack" + assert ctx.objective_target_conversation_id == "conv-789" + assert ctx.endpoint == "https://example.com" + + assert get_execution_context() is None + + def test_with_execution_context_preserves_on_exception(self): + """Test that context is preserved on exception for error handling.""" + with pytest.raises(RuntimeError): + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name="TestAttack", + ): 
+ raise RuntimeError("Scorer failed") + + # Context should still be available for exception handlers + ctx = get_execution_context() + assert ctx is not None + assert ctx.component_role == ComponentRole.OBJECTIVE_SCORER diff --git a/tests/unit/exceptions/test_exceptions_helpers.py b/tests/unit/exceptions/test_exceptions_helpers.py index d8e7762f7..b576b1c87 100644 --- a/tests/unit/exceptions/test_exceptions_helpers.py +++ b/tests/unit/exceptions/test_exceptions_helpers.py @@ -8,8 +8,20 @@ remove_end_md_json, remove_markdown_json, remove_start_md_json, + log_exception ) +from pyrit.exceptions import ( + ComponentRole, + ExecutionContext, + clear_execution_context, + set_execution_context, +) + +# Tests for log_exception with execution context +from concurrent.futures import Future +from unittest.mock import MagicMock, patch + @pytest.mark.parametrize( "input_str, expected_output", @@ -69,3 +81,125 @@ def test_extract_json_from_string(input_str, expected_output): ) def test_remove_markdown_json(input_str, expected_output): assert remove_markdown_json(input_str) == expected_output + + + + + +class TestLogException: + """Tests for the log_exception function with execution context.""" + + def teardown_method(self): + """Clear context after each test.""" + clear_execution_context() + + def test_log_exception_without_context(self): + """Test log_exception works when no execution context is set.""" + # Create a mock retry state + retry_state = MagicMock() + retry_state.attempt_number = 2 + retry_state.start_time = 0.0 + + # Create a failed outcome with an exception + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = ValueError("Test error") + retry_state.outcome = outcome + + retry_state.fn = MagicMock() + retry_state.fn.__name__ = "test_function" + + with patch("pyrit.exceptions.exceptions_helpers.logger") as mock_logger: + log_exception(retry_state) + mock_logger.error.assert_called_once() + call_args = mock_logger.error.call_args[0][0] + assert "test_function" in call_args + assert "Test error" in call_args + # Should just have function name when no context is set + assert "objective target" not in call_args + + def test_log_exception_with_context_and_component_name(self): + """Test log_exception includes component role and class name when set.""" + # Set execution context with component name + context = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_SCORER, + component_name="TrueFalseScorer", + endpoint="https://api.openai.com", + ) + set_execution_context(context) + + # Create a mock retry state + retry_state = MagicMock() + retry_state.attempt_number = 3 + retry_state.start_time = 0.0 + + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = ConnectionError("Connection failed") + retry_state.outcome = outcome + + retry_state.fn = MagicMock() + retry_state.fn.__name__ = "_score_value_with_llm" + + with patch("pyrit.exceptions.exceptions_helpers.logger") as mock_logger: + log_exception(retry_state) + mock_logger.error.assert_called_once() + call_args = mock_logger.error.call_args[0][0] + # New format: "objective scorer; TrueFalseScorer::_score_value_with_llm" + assert "objective scorer" in call_args + assert "TrueFalseScorer::_score_value_with_llm" in call_args + assert "Connection failed" in call_args + + def test_log_exception_with_context_no_component_name(self): + """Test log_exception with context but no component name.""" + context = ExecutionContext(component_role=ComponentRole.CONVERTER) + 
set_execution_context(context) + + retry_state = MagicMock() + retry_state.attempt_number = 1 + retry_state.start_time = 0.0 + + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = RuntimeError("Conversion failed") + retry_state.outcome = outcome + + retry_state.fn = MagicMock() + retry_state.fn.__name__ = "convert_async" + + with patch("pyrit.exceptions.exceptions_helpers.logger") as mock_logger: + log_exception(retry_state) + call_args = mock_logger.error.call_args[0][0] + # Without component name: "converter; convert_async" + assert "converter. convert_async" in call_args + # Should not have "::" since no component name + assert "::" not in call_args + + def test_log_exception_no_retry_state(self): + """Test log_exception handles None retry_state gracefully.""" + with patch("pyrit.exceptions.exceptions_helpers.logger") as mock_logger: + log_exception(None) + mock_logger.error.assert_called_once() + assert "no retry state" in mock_logger.error.call_args[0][0].lower() + + def test_log_exception_no_outcome(self): + """Test log_exception handles missing outcome gracefully.""" + retry_state = MagicMock() + retry_state.outcome = None + + with patch("pyrit.exceptions.exceptions_helpers.logger") as mock_logger: + log_exception(retry_state) + # Should return early without logging error details + mock_logger.error.assert_not_called() + + def test_log_exception_outcome_not_failed(self): + """Test log_exception doesn't log when outcome is not failed.""" + retry_state = MagicMock() + retry_state.attempt_number = 1 + outcome = MagicMock() + outcome.failed = False + retry_state.outcome = outcome + + with patch("pyrit.exceptions.exceptions_helpers.logger") as mock_logger: + log_exception(retry_state) + mock_logger.error.assert_not_called() diff --git a/tests/unit/executor/attack/multi_turn/test_chunked_request.py b/tests/unit/executor/attack/multi_turn/test_chunked_request.py index 2d6ee807e..b3bf1e894 100644 --- a/tests/unit/executor/attack/multi_turn/test_chunked_request.py +++ b/tests/unit/executor/attack/multi_turn/test_chunked_request.py @@ -40,6 +40,7 @@ def test_context_with_chunk_responses(self): assert context.chunk_responses == ["abc", "def", "ghi"] +@pytest.mark.usefixtures("patch_central_database") class TestChunkedRequestAttack: """Test the ChunkedRequestAttack class.""" diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index ed0484c1c..4e04abeb0 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -1127,6 +1127,7 @@ def node_components(self, attack_builder): "response_converters": [], "auxiliary_scorers": [], "attack_id": {"id": "test_attack"}, + "attack_strategy_name": "TreeOfAttacksWithPruningAttack", "memory_labels": {"test": "label"}, "parent_id": None, "prompt_normalizer": prompt_normalizer, diff --git a/tests/unit/executor/core/test_strategy.py b/tests/unit/executor/core/test_strategy.py new file mode 100644 index 000000000..ae3a58ec5 --- /dev/null +++ b/tests/unit/executor/core/test_strategy.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.exceptions import ( + ComponentRole, + clear_execution_context, + get_execution_context, + with_execution_context, +) +from pyrit.executor.core.strategy import Strategy, StrategyContext + + +@dataclass +class MockContext(StrategyContext): + """A mock context for testing.""" + + value: str = "test" + + +class MockStrategy(Strategy[MockContext, str]): + """A mock strategy for testing.""" + + def __init__(self, perform_result: str = "success", perform_exception: Exception = None): + # Initialize base class with the context type + super().__init__(context_type=MockContext) + self._perform_result = perform_result + self._perform_exception = perform_exception + + async def _setup_async(self, *, context: MockContext) -> None: + pass + + async def _perform_async(self, *, context: MockContext) -> str: + if self._perform_exception: + raise self._perform_exception + return self._perform_result + + async def _teardown_async(self, *, context: MockContext) -> None: + pass + + def _validate_context(self, *, context: MockContext) -> None: + pass + + +@pytest.mark.usefixtures("patch_central_database") +class TestStrategyExecutionContext: + """Tests for Strategy execution context handling.""" + + def teardown_method(self): + """Clear context after each test.""" + clear_execution_context() + + @pytest.mark.asyncio + async def test_execute_with_context_success_clears_context(self): + """Test that successful execution clears execution context.""" + strategy = MockStrategy(perform_result="success") + context = MockContext() + + # Set a context before execution + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="TestStrategy", + ): + result = await strategy.execute_with_context_async(context=context) + + assert result == "success" + # Context should be cleared after successful execution + # (cleared by the context manager on successful exit) + + @pytest.mark.asyncio + async def test_execute_with_context_exception_includes_context(self): + """Test that exceptions include execution context details.""" + strategy = MockStrategy(perform_exception=ValueError("Test error")) + context = MockContext() + + # The strategy wraps the exception with context details + with pytest.raises(RuntimeError) as exc_info: + # First set an execution context (simulating what attacks do) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="MockStrategy", + attack_identifier={"id": "test-123"}, + ): + await strategy.execute_with_context_async(context=context) + + error_message = str(exc_info.value) + # Should include component role + assert "objective_target" in error_message + # Should include strategy name + assert "MockStrategy" in error_message + + @pytest.mark.asyncio + async def test_execute_with_context_exception_without_context(self): + """Test that exceptions work even without execution context.""" + strategy = MockStrategy(perform_exception=RuntimeError("Something went wrong")) + context = MockContext() + + with pytest.raises(RuntimeError) as exc_info: + await strategy.execute_with_context_async(context=context) + + error_message = str(exc_info.value) + assert "MockStrategy" in error_message + assert "Something went wrong" in error_message + + @pytest.mark.asyncio + async def test_execute_with_context_preserves_root_cause(self): + """Test that the original exception is preserved as 
__cause__.""" + original_error = ValueError("Original error") + strategy = MockStrategy(perform_exception=original_error) + context = MockContext() + + with pytest.raises(RuntimeError) as exc_info: + await strategy.execute_with_context_async(context=context) + + # The __cause__ should be the original exception + assert exc_info.value.__cause__ is original_error + + @pytest.mark.asyncio + async def test_execute_with_context_extracts_root_cause(self): + """Test that chained exceptions show root cause in error message.""" + # Create a chain of exceptions + root_cause = ConnectionError("Connection refused") + middle_error = RuntimeError("API call failed") + middle_error.__cause__ = root_cause + + strategy = MockStrategy(perform_exception=middle_error) + context = MockContext() + + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="MockStrategy", + ): + with pytest.raises(RuntimeError) as exc_info: + await strategy.execute_with_context_async(context=context) + + error_message = str(exc_info.value) + # Should include root cause information + assert "Root cause" in error_message + assert "ConnectionError" in error_message + assert "Connection refused" in error_message + + +@pytest.mark.usefixtures("patch_central_database") +class TestStrategyExecutionContextDetails: + """Tests for execution context detail extraction in strategy errors.""" + + def teardown_method(self): + """Clear context after each test.""" + clear_execution_context() + + @pytest.mark.asyncio + async def test_error_includes_attack_identifier(self): + """Test that error message includes attack identifier.""" + strategy = MockStrategy(perform_exception=ValueError("Error")) + context = MockContext() + + with pytest.raises(RuntimeError) as exc_info: + with with_execution_context( + component_role=ComponentRole.ADVERSARIAL_CHAT, + attack_strategy_name="TestAttack", + attack_identifier={"__type__": "TestAttack", "id": "abc-123"}, + ): + await strategy.execute_with_context_async(context=context) + + error_message = str(exc_info.value) + assert "Attack identifier:" in error_message + assert "abc-123" in error_message + + @pytest.mark.asyncio + async def test_error_includes_conversation_id(self): + """Test that error message includes objective target conversation ID.""" + strategy = MockStrategy(perform_exception=ValueError("Error")) + context = MockContext() + + with pytest.raises(RuntimeError) as exc_info: + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="TestAttack", + objective_target_conversation_id="conv-xyz-789", + ): + await strategy.execute_with_context_async(context=context) + + error_message = str(exc_info.value) + assert "Objective target conversation ID: conv-xyz-789" in error_message + + @pytest.mark.asyncio + async def test_error_includes_component_identifier(self): + """Test that error message includes component identifier.""" + strategy = MockStrategy(perform_exception=ValueError("Error")) + context = MockContext() + + with pytest.raises(RuntimeError) as exc_info: + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_SCORER, + attack_strategy_name="TestAttack", + component_identifier={"__type__": "SelfAskTrueFalseScorer"}, + ): + await strategy.execute_with_context_async(context=context) + + error_message = str(exc_info.value) + assert "objective_scorer identifier:" in error_message + assert "SelfAskTrueFalseScorer" in error_message diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py 
b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 07319440b..b3d5b3f64 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -24,6 +24,12 @@ PromptConverter, StringJoinConverter, ) +from pyrit.exceptions import ( + ComponentRole, + clear_execution_context, + get_execution_context, + with_execution_context, +) from pyrit.prompt_normalizer import NormalizerRequest, PromptNormalizer from pyrit.prompt_normalizer.prompt_converter_configuration import ( PromptConverterConfiguration, @@ -428,3 +434,121 @@ async def test_send_prompt_async_exception_conv_id(mock_memory_instance, seed_gr "Test Exception" in mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"].message_pieces[0].original_value ) + + +# Tests for execution context in converter operations (used for error message handling) + +class ContextCapturingConverter(PromptConverter): + """A converter that captures the execution context during conversion.""" + + SUPPORTED_INPUT_TYPES: tuple[PromptDataType, ...] = ("text",) + SUPPORTED_OUTPUT_TYPES: tuple[PromptDataType, ...] = ("text",) + captured_context = None + + def __init__(self) -> None: + pass + + async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: + # Capture the current execution context + ContextCapturingConverter.captured_context = get_execution_context() + return ConverterResult(output_text=f"converted:{prompt}", output_type="text") + + def input_supported(self, input_type: PromptDataType) -> bool: + return input_type == "text" + + def output_supported(self, output_type: PromptDataType) -> bool: + return output_type == "text" + + +class FailingConverter(PromptConverter): + """A converter that raises an exception during conversion.""" + + SUPPORTED_INPUT_TYPES: tuple[PromptDataType, ...] = ("text",) + SUPPORTED_OUTPUT_TYPES: tuple[PromptDataType, ...] 
= ("text",) + + def __init__(self) -> None: + pass + + async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: + raise RuntimeError("Converter failed") + + def input_supported(self, input_type: PromptDataType) -> bool: + return input_type == "text" + + def output_supported(self, output_type: PromptDataType) -> bool: + return output_type == "text" + + +class TestPromptNormalizerConverterContext: + """Tests for execution context during converter operations in PromptNormalizer.""" + + def teardown_method(self): + """Clear context after each test.""" + clear_execution_context() + ContextCapturingConverter.captured_context = None + + @pytest.mark.asyncio + async def test_convert_values_sets_converter_context(self, mock_memory_instance): + """Test that convert_values sets CONVERTER execution context.""" + normalizer = PromptNormalizer() + message = Message.from_prompt(prompt="test", role="user") + + converter_config = PromptConverterConfiguration(converters=[ContextCapturingConverter()]) + + await normalizer.convert_values(converter_configurations=[converter_config], message=message) + + # The converter should have captured the execution context + captured = ContextCapturingConverter.captured_context + assert captured is not None + assert captured.component_role == ComponentRole.CONVERTER + + @pytest.mark.asyncio + async def test_convert_values_inherits_outer_context(self, mock_memory_instance): + """Test that converter context inherits attack info from outer context.""" + normalizer = PromptNormalizer() + message = Message.from_prompt(prompt="test", role="user") + + converter_config = PromptConverterConfiguration(converters=[ContextCapturingConverter()]) + + # Set an outer execution context (simulating being called from an attack) + with with_execution_context( + component_role=ComponentRole.OBJECTIVE_TARGET, + attack_strategy_name="TestAttack", + attack_identifier={"id": "attack-123"}, + objective_target_conversation_id="conv-456", + ): + await normalizer.convert_values(converter_configurations=[converter_config], message=message) + + # The converter should have captured the context with inherited values + captured = ContextCapturingConverter.captured_context + assert captured is not None + assert captured.component_role == ComponentRole.CONVERTER + assert captured.attack_strategy_name == "TestAttack" + assert captured.objective_target_conversation_id == "conv-456" + + @pytest.mark.asyncio + async def test_convert_values_exception_propagates(self, mock_memory_instance): + """Test that converter exceptions propagate correctly.""" + normalizer = PromptNormalizer() + message = Message.from_prompt(prompt="test", role="user") + + converter_config = PromptConverterConfiguration(converters=[FailingConverter()]) + + with pytest.raises(RuntimeError, match="Converter failed"): + await normalizer.convert_values(converter_configurations=[converter_config], message=message) + + @pytest.mark.asyncio + async def test_convert_values_context_includes_converter_identifier(self, mock_memory_instance): + """Test that converter context includes the converter's identifier.""" + normalizer = PromptNormalizer() + message = Message.from_prompt(prompt="test", role="user") + + converter = ContextCapturingConverter() + converter_config = PromptConverterConfiguration(converters=[converter]) + + await normalizer.convert_values(converter_configurations=[converter_config], message=message) + + captured = ContextCapturingConverter.captured_context + assert captured is not None + 
assert captured.component_identifier is not None + assert "ContextCapturingConverter" in str(captured.component_identifier) From d48ee9f15a66a513fb97d775c1f7a2d4b5391a6e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 16 Jan 2026 23:41:06 -0800 Subject: [PATCH 2/3] pre-commit --- pyrit/executor/core/strategy.py | 2 +- .../exceptions/test_exceptions_helpers.py | 25 ++++++++----------- tests/unit/executor/core/test_strategy.py | 3 --- .../test_prompt_normalizer.py | 15 +++++------ 4 files changed, 19 insertions(+), 26 deletions(-) diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index d7b9e1d22..7fc48a417 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -357,7 +357,7 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra error_details = exec_context.get_exception_details() # Extract the root cause exception for better diagnostics - root_cause = e + root_cause: BaseException = e while root_cause.__cause__ is not None: root_cause = root_cause.__cause__ diff --git a/tests/unit/exceptions/test_exceptions_helpers.py b/tests/unit/exceptions/test_exceptions_helpers.py index b576b1c87..c8ddccb53 100644 --- a/tests/unit/exceptions/test_exceptions_helpers.py +++ b/tests/unit/exceptions/test_exceptions_helpers.py @@ -1,15 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import pytest +# Tests for log_exception with execution context +from unittest.mock import MagicMock, patch -from pyrit.exceptions.exceptions_helpers import ( - extract_json_from_string, - remove_end_md_json, - remove_markdown_json, - remove_start_md_json, - log_exception -) +import pytest from pyrit.exceptions import ( ComponentRole, @@ -17,10 +12,13 @@ clear_execution_context, set_execution_context, ) - -# Tests for log_exception with execution context -from concurrent.futures import Future -from unittest.mock import MagicMock, patch +from pyrit.exceptions.exceptions_helpers import ( + extract_json_from_string, + log_exception, + remove_end_md_json, + remove_markdown_json, + remove_start_md_json, +) @pytest.mark.parametrize( @@ -83,9 +81,6 @@ def test_remove_markdown_json(input_str, expected_output): assert remove_markdown_json(input_str) == expected_output - - - class TestLogException: """Tests for the log_exception function with execution context.""" diff --git a/tests/unit/executor/core/test_strategy.py b/tests/unit/executor/core/test_strategy.py index ae3a58ec5..28636436a 100644 --- a/tests/unit/executor/core/test_strategy.py +++ b/tests/unit/executor/core/test_strategy.py @@ -2,15 +2,12 @@ # Licensed under the MIT license. 
from dataclasses import dataclass -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch import pytest from pyrit.exceptions import ( ComponentRole, clear_execution_context, - get_execution_context, with_execution_context, ) from pyrit.executor.core.strategy import Strategy, StrategyContext diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index b3d5b3f64..f0c017dcc 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -9,7 +9,13 @@ import pytest from unit.mocks import MockPromptTarget, get_image_message_piece -from pyrit.exceptions import EmptyResponseException +from pyrit.exceptions import ( + ComponentRole, + EmptyResponseException, + clear_execution_context, + get_execution_context, + with_execution_context, +) from pyrit.memory import CentralMemory from pyrit.models import ( Message, @@ -24,12 +30,6 @@ PromptConverter, StringJoinConverter, ) -from pyrit.exceptions import ( - ComponentRole, - clear_execution_context, - get_execution_context, - with_execution_context, -) from pyrit.prompt_normalizer import NormalizerRequest, PromptNormalizer from pyrit.prompt_normalizer.prompt_converter_configuration import ( PromptConverterConfiguration, @@ -438,6 +438,7 @@ async def test_send_prompt_async_exception_conv_id(mock_memory_instance, seed_gr # Tests for execution context in converter operations (used for error message handling) + class ContextCapturingConverter(PromptConverter): """A converter that captures the execution context during conversion.""" From 563999a9272fb6d37af8e7cec86453de6f51d8cd Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 16 Jan 2026 23:55:01 -0800 Subject: [PATCH 3/3] adding objective --- pyrit/exceptions/exception_context.py | 13 ++++++++ .../attack/multi_turn/chunked_request.py | 2 ++ pyrit/executor/attack/multi_turn/crescendo.py | 4 +++ .../attack/multi_turn/multi_prompt_sending.py | 2 ++ .../executor/attack/multi_turn/red_teaming.py | 3 ++ .../attack/multi_turn/tree_of_attacks.py | 10 ++++++ .../attack/single_turn/prompt_sending.py | 2 ++ .../unit/exceptions/test_exception_context.py | 31 +++++++++++++++++++ 8 files changed, 67 insertions(+) diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py index 9ff68850d..6ad251295 100644 --- a/pyrit/exceptions/exception_context.py +++ b/pyrit/exceptions/exception_context.py @@ -74,6 +74,9 @@ class ExecutionContext: # The component class name (extracted from component_identifier.__type__ for quick access) component_name: Optional[str] = None + # The attack objective if available + objective: Optional[str] = None + def get_retry_context_string(self) -> str: """ Generate a concise context string for retry log messages. @@ -102,6 +105,13 @@ def get_exception_details(self) -> str: lines.append(f"Component: {self.component_role.value}") + if self.objective: + # Normalize to single line and truncate to 120 characters + objective_single_line = " ".join(self.objective.split()) + if len(objective_single_line) > 120: + objective_single_line = objective_single_line[:117] + "..." 
+ lines.append(f"Objective: {objective_single_line}") + if self.objective_target_conversation_id: lines.append(f"Objective target conversation ID: {self.objective_target_conversation_id}") @@ -183,6 +193,7 @@ def with_execution_context( attack_identifier: Optional[Dict[str, Any]] = None, component_identifier: Optional[Dict[str, Any]] = None, objective_target_conversation_id: Optional[str] = None, + objective: Optional[str] = None, ) -> ExecutionContextManager: """ Create an execution context manager with the specified parameters. @@ -193,6 +204,7 @@ def with_execution_context( attack_identifier: The identifier from attack.get_identifier(). component_identifier: The identifier from component.get_identifier(). objective_target_conversation_id: The objective target conversation ID if available. + objective: The attack objective if available. Returns: ExecutionContextManager: A context manager that sets/clears the context. @@ -212,5 +224,6 @@ def with_execution_context( objective_target_conversation_id=objective_target_conversation_id, endpoint=endpoint, component_name=component_name, + objective=objective, ) return ExecutionContextManager(context=context) diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 91a24ed6e..3dcf8a9fc 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -267,6 +267,7 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): response = await self._prompt_normalizer.send_prompt_async( message=message, @@ -362,6 +363,7 @@ async def _score_combined_value_async( attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), + objective=objective, ): scores = await self._objective_scorer.score_text_async(text=combined_value, objective=objective) return scores[0] if scores else None diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index a7e618161..d52f23128 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -514,6 +514,7 @@ async def _send_prompt_to_adversarial_chat_async( attack_identifier=self.get_identifier(), component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): response = await self._prompt_normalizer.send_prompt_async( message=message, @@ -597,6 +598,7 @@ async def _send_prompt_to_objective_target_async( attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): response = await self._prompt_normalizer.send_prompt_async( message=attack_message, @@ -636,6 +638,7 @@ async def _check_refusal_async(self, context: CrescendoAttackContext, objective: attack_identifier=self.get_identifier(), component_identifier=self._refusal_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): scores = await self._refusal_scorer.score_async( message=context.last_response, objective=objective, 
skip_on_error_result=False @@ -665,6 +668,7 @@ async def _score_response_async(self, *, context: CrescendoAttackContext) -> Sco attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): scoring_results = await Scorer.score_response_async( response=context.last_response, diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index c3bf40db4..97c2a059a 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -341,6 +341,7 @@ async def _send_prompt_to_objective_target_async( attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): return await self._prompt_normalizer.send_prompt_async( message=current_message, @@ -373,6 +374,7 @@ async def _evaluate_response_async(self, *, response: Message, objective: str) - attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, + objective=objective, ): scoring_results = await Scorer.score_response_async( response=response, diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index a121e714a..8ee68b4c3 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -368,6 +368,7 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] attack_identifier=self.get_identifier(), component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): response = await self._prompt_normalizer.send_prompt_async( message=prompt_message, @@ -526,6 +527,7 @@ async def _send_prompt_to_objective_target_async( attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): # Send the message to the target response = await self._prompt_normalizer.send_prompt_async( @@ -573,6 +575,7 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, + objective=context.objective, ): # score_async handles blocked, filtered, other errors scoring_results = await self._objective_scorer.score_async( diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 74c9fa86a..a1a639943 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -341,6 +341,9 @@ def __init__( # This supports multimodal messages self._initial_prompt: Optional[Message] = initial_prompt + # Current objective (set when send_prompt_async is called) + self._objective: Optional[str] = None + async def initialize_with_prepended_conversation_async( self, *, @@ -424,6 +427,9 @@ async def send_prompt_async(self, objective: str) -> None: - `off_topic`: `True` if the prompt was 
deemed off-topic after all retries - `error_message`: Set if an error occurred during execution """ + # Store objective for use in execution context + self._objective = objective + try: # Check if we have an initial prompt to use (bypasses adversarial generation) if self._initial_prompt and self._is_first_turn(): @@ -525,6 +531,7 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: attack_identifier=self._attack_id, component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, + objective=self._objective, ): response = await self._prompt_normalizer.send_prompt_async( message=message, @@ -580,6 +587,7 @@ async def _send_initial_prompt_to_target_async(self) -> Message: attack_identifier=self._attack_id, component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, + objective=self._objective, ): response = await self._prompt_normalizer.send_prompt_async( message=message, @@ -633,6 +641,7 @@ async def _score_response_async(self, *, response: Message, objective: str) -> N attack_identifier=self._attack_id, component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, + objective=objective, ): scoring_results = await Scorer.score_response_async( response=response, @@ -1078,6 +1087,7 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: attack_identifier=self._attack_id, component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, + objective=self._objective, ): response = await self._prompt_normalizer.send_prompt_async( message=message, diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 9aa6f2892..07d65cb4c 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -318,6 +318,7 @@ async def _send_prompt_to_objective_target_async( attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.conversation_id, + objective=context.params.objective, ): return await self._prompt_normalizer.send_prompt_async( message=message, @@ -355,6 +356,7 @@ async def _evaluate_response_async( attack_strategy_name=self.__class__.__name__, attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, + objective=objective, ): scoring_results = await Scorer.score_response_async( response=response, diff --git a/tests/unit/exceptions/test_exception_context.py b/tests/unit/exceptions/test_exception_context.py index 99c4b4c7b..751543f61 100644 --- a/tests/unit/exceptions/test_exception_context.py +++ b/tests/unit/exceptions/test_exception_context.py @@ -27,6 +27,7 @@ def test_default_values(self): assert context.objective_target_conversation_id is None assert context.endpoint is None assert context.component_name is None + assert context.objective is None def test_initialization_with_values(self): """Test ExecutionContext initialization with all values.""" @@ -38,6 +39,7 @@ def test_initialization_with_values(self): objective_target_conversation_id="conv-123", endpoint="https://api.openai.com", component_name="OpenAIChatTarget", + objective="Tell me how to hack a system", ) assert 
context.component_role == ComponentRole.OBJECTIVE_TARGET assert context.attack_strategy_name == "PromptSendingAttack" @@ -46,6 +48,7 @@ def test_initialization_with_values(self): assert context.objective_target_conversation_id == "conv-123" assert context.endpoint == "https://api.openai.com" assert context.component_name == "OpenAIChatTarget" + assert context.objective == "Tell me how to hack a system" def test_get_retry_context_string_minimal(self): """Test retry context string with only component role.""" @@ -99,14 +102,42 @@ def test_get_exception_details_full(self): attack_identifier={"__type__": "RedTeamingAttack", "id": "xyz"}, component_identifier={"__type__": "OpenAIChatTarget"}, objective_target_conversation_id="conv-456", + objective="Tell me how to hack a system", ) result = context.get_exception_details() assert "Attack: RedTeamingAttack" in result assert "Component: objective_target" in result + assert "Objective: Tell me how to hack a system" in result assert "Objective target conversation ID: conv-456" in result assert "Attack identifier:" in result assert "objective_target identifier:" in result + def test_get_exception_details_objective_truncation(self): + """Test that long objectives are truncated to 120 characters.""" + long_objective = "A" * 200 # 200 character objective + context = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_TARGET, + objective=long_objective, + ) + result = context.get_exception_details() + # Should be truncated to 117 chars + "..." + assert "Objective: " + "A" * 117 + "..." in result + # Full objective should not appear + assert long_objective not in result + + def test_get_exception_details_objective_single_line(self): + """Test that objectives with newlines are normalized to single line.""" + multiline_objective = "Tell me how to\nhack a system\nwith multiple lines" + context = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_TARGET, + objective=multiline_objective, + ) + result = context.get_exception_details() + # Should be on single line with spaces instead of newlines + assert "Objective: Tell me how to hack a system with multiple lines" in result + # No newlines in the objective line + assert "\nhack" not in result + class TestExecutionContextFunctions: """Tests for the context management functions."""