Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion pyrit/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,26 @@
pyrit_placeholder_retry,
pyrit_target_retry,
)
from pyrit.exceptions.exception_context import (
ComponentRole,
ExecutionContext,
ExecutionContextManager,
clear_execution_context,
get_execution_context,
set_execution_context,
with_execution_context,
)
from pyrit.exceptions.exceptions_helpers import remove_markdown_json

__all__ = [
"BadRequestException",
"clear_execution_context",
"ComponentRole",
"EmptyResponseException",
"ExecutionContext",
"ExecutionContextManager",
"get_execution_context",
"get_retry_max_num_attempts",
"handle_bad_request_exception",
"InvalidJsonException",
"MissingPromptPlaceholderException",
Expand All @@ -30,5 +45,6 @@
"pyrit_placeholder_retry",
"RateLimitException",
"remove_markdown_json",
"get_retry_max_num_attempts",
"set_execution_context",
"with_execution_context",
]
229 changes: 229 additions & 0 deletions pyrit/exceptions/exception_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Context management for enhanced exception and retry logging in PyRIT.

This module provides a contextvar-based system for tracking which component
(objective_target, adversarial_chat, objective_scorer, etc.) is currently
executing, allowing retry decorators and exception handlers to include
meaningful context in their messages.
"""

from contextvars import ContextVar
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional


class ComponentRole(Enum):
"""
Identifies the role of a component within an attack execution.

This enum is used to provide meaningful context in error messages and retry logs,
helping users identify which part of an attack encountered an issue.
"""

# Core attack components
OBJECTIVE_TARGET = "objective_target"
ADVERSARIAL_CHAT = "adversarial_chat"

# Scoring components
OBJECTIVE_SCORER = "objective_scorer"
OBJECTIVE_SCORER_TARGET = "objective_scorer_target"
REFUSAL_SCORER = "refusal_scorer"
REFUSAL_SCORER_TARGET = "refusal_scorer_target"
AUXILIARY_SCORER = "auxiliary_scorer"
AUXILIARY_SCORER_TARGET = "auxiliary_scorer_target"

# Conversion components
CONVERTER = "converter"
CONVERTER_TARGET = "converter_target"

# Other components
UNKNOWN = "unknown"


@dataclass
class ExecutionContext:
"""
Holds context information about the currently executing component.

This context is used to enrich error messages and retry logs with
information about which component failed and its configuration.
"""

# The role of the component (e.g., objective_scorer, adversarial_chat)
component_role: ComponentRole = ComponentRole.UNKNOWN

# The attack strategy class name (e.g., "PromptSendingAttack")
attack_strategy_name: Optional[str] = None

# The identifier from the attack strategy's get_identifier()
attack_identifier: Optional[Dict[str, Any]] = None

# The identifier from the component's get_identifier() (target, scorer, etc.)
component_identifier: Optional[Dict[str, Any]] = None

# The objective target conversation ID if available
objective_target_conversation_id: Optional[str] = None

# The endpoint/URI if available (extracted from component_identifier for quick access)
endpoint: Optional[str] = None

# The component class name (extracted from component_identifier.__type__ for quick access)
component_name: Optional[str] = None

# The attack objective if available
objective: Optional[str] = None

def get_retry_context_string(self) -> str:
"""
Generate a concise context string for retry log messages.

Returns:
str: A formatted string with component role, component name, and endpoint.
"""
parts = [self.component_role.value]
if self.component_name:
parts.append(f"({self.component_name})")
if self.endpoint:
parts.append(f"endpoint: {self.endpoint}")
return " ".join(parts)

def get_exception_details(self) -> str:
"""
Generate detailed exception context for error messages.

Returns:
str: A multi-line formatted string with full context details.
"""
lines = []

if self.attack_strategy_name:
lines.append(f"Attack: {self.attack_strategy_name}")

lines.append(f"Component: {self.component_role.value}")

if self.objective:
# Normalize to single line and truncate to 120 characters
objective_single_line = " ".join(self.objective.split())
if len(objective_single_line) > 120:
objective_single_line = objective_single_line[:117] + "..."
lines.append(f"Objective: {objective_single_line}")

if self.objective_target_conversation_id:
lines.append(f"Objective target conversation ID: {self.objective_target_conversation_id}")

if self.attack_identifier:
lines.append(f"Attack identifier: {self.attack_identifier}")

if self.component_identifier:
lines.append(f"{self.component_role.value} identifier: {self.component_identifier}")

return "\n".join(lines)


# The contextvar that stores the current execution context
_execution_context: ContextVar[Optional[ExecutionContext]] = ContextVar("execution_context", default=None)


def get_execution_context() -> Optional[ExecutionContext]:
"""
Get the current execution context.

Returns:
Optional[ExecutionContext]: The current context, or None if not set.
"""
return _execution_context.get()


def set_execution_context(context: ExecutionContext) -> None:
"""
Set the current execution context.

Args:
context: The execution context to set.
"""
_execution_context.set(context)


def clear_execution_context() -> None:
"""Clear the current execution context."""
_execution_context.set(None)


@dataclass
class ExecutionContextManager:
"""
A context manager for setting execution context during component operations.

This class provides a convenient way to set and automatically clear
execution context when entering and exiting a code block.

On successful exit, the context is restored to its previous value.
On exception, the context is preserved so exception handlers can access it.
"""

context: ExecutionContext
_token: Any = field(default=None, init=False, repr=False)

def __enter__(self) -> "ExecutionContextManager":
"""Set the execution context on entry."""
self._token = _execution_context.set(self.context)
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""
Restore the previous context on exit.

If an exception occurred, the context is preserved so that exception
handlers higher in the call stack can access it for enhanced error messages.
"""
if exc_type is None:
# No exception - restore previous context
_execution_context.reset(self._token)
# On exception, leave context in place for exception handlers to read


def with_execution_context(
*,
component_role: ComponentRole,
attack_strategy_name: Optional[str] = None,
attack_identifier: Optional[Dict[str, Any]] = None,
component_identifier: Optional[Dict[str, Any]] = None,
objective_target_conversation_id: Optional[str] = None,
objective: Optional[str] = None,
) -> ExecutionContextManager:
"""
Create an execution context manager with the specified parameters.

Args:
component_role: The role of the component being executed.
attack_strategy_name: The name of the attack strategy class.
attack_identifier: The identifier from attack.get_identifier().
component_identifier: The identifier from component.get_identifier().
objective_target_conversation_id: The objective target conversation ID if available.
objective: The attack objective if available.

Returns:
ExecutionContextManager: A context manager that sets/clears the context.
"""
# Extract endpoint and component_name from component_identifier if available
endpoint = None
component_name = None
if component_identifier:
endpoint = component_identifier.get("endpoint")
component_name = component_identifier.get("__type__")

context = ExecutionContext(
component_role=component_role,
attack_strategy_name=attack_strategy_name,
attack_identifier=attack_identifier,
component_identifier=component_identifier,
objective_target_conversation_id=objective_target_conversation_id,
endpoint=endpoint,
component_name=component_name,
objective=objective,
)
return ExecutionContextManager(context=context)
61 changes: 51 additions & 10 deletions pyrit/exceptions/exceptions_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,57 @@


def log_exception(retry_state: RetryCallState) -> None:
# Log each retry attempt with exception details at ERROR level
elapsed_time = time.monotonic() - retry_state.start_time
call_count = retry_state.attempt_number

if retry_state.outcome.failed:
exception = retry_state.outcome.exception()
logger.error(
f"Retry attempt {call_count} for {retry_state.fn.__name__} failed with exception: {exception}. "
f"Elapsed time: {elapsed_time} seconds. Total calls: {call_count}"
)
"""
Log each retry attempt with exception details at ERROR level.

If an execution context is set (via exception_context module), includes
component role and endpoint information for easier debugging.

Args:
retry_state: The tenacity retry state containing attempt information.
"""
# Validate retry_state has required attributes before proceeding
if not retry_state:
logger.error("Retry callback invoked with no retry state")
return

# Safely extract values with defaults
call_count = getattr(retry_state, "attempt_number", None) or 0
start_time = getattr(retry_state, "start_time", None)
elapsed_time = (time.monotonic() - start_time) if start_time is not None else 0.0

outcome = getattr(retry_state, "outcome", None)
if not outcome or not getattr(outcome, "failed", False):
return

exception = outcome.exception() if hasattr(outcome, "exception") else None

# Get function name safely
fn = getattr(retry_state, "fn", None)
fn_name = getattr(fn, "__name__", "unknown") if fn else "unknown"

# Build the "for X" part of the message based on execution context
for_clause = fn_name
try:
from pyrit.exceptions.exception_context import get_execution_context

exec_context = get_execution_context()
if exec_context:
# Format: "objective scorer; TrueFalseScorer::_score_value_with_llm"
role_display = exec_context.component_role.value.replace("_", " ")
if exec_context.component_name:
for_clause = f"{role_display}. {exec_context.component_name}::{fn_name}"
else:
for_clause = f"{role_display}. {fn_name}"
except Exception:
# Don't let context retrieval errors break retry logging
pass

logger.error(
f"Retry attempt {call_count} for {for_clause} "
f"failed with exception: {exception}. "
f"Elapsed time: {elapsed_time} seconds. Total calls: {call_count}"
)


def remove_start_md_json(response_msg: str) -> str:
Expand Down
34 changes: 25 additions & 9 deletions pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, List, Optional

from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults
from pyrit.exceptions import ComponentRole, with_execution_context
from pyrit.executor.attack.component import ConversationManager
from pyrit.executor.attack.core.attack_config import (
AttackConverterConfig,
Expand Down Expand Up @@ -260,15 +261,23 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac
message = Message.from_prompt(prompt=chunk_prompt, role="user")

# Send the prompt using the normalizer
response = await self._prompt_normalizer.send_prompt_async(
message=message,
target=self._objective_target,
conversation_id=context.session.conversation_id,
request_converter_configurations=self._request_converters,
response_converter_configurations=self._response_converters,
labels=context.memory_labels,
with with_execution_context(
component_role=ComponentRole.OBJECTIVE_TARGET,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
)
component_identifier=self._objective_target.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
):
response = await self._prompt_normalizer.send_prompt_async(
message=message,
target=self._objective_target,
conversation_id=context.session.conversation_id,
request_converter_configurations=self._request_converters,
response_converter_configurations=self._response_converters,
labels=context.memory_labels,
attack_identifier=self.get_identifier(),
)

# Store the response
if response:
Expand Down Expand Up @@ -349,7 +358,14 @@ async def _score_combined_value_async(
if not self._objective_scorer:
return None

scores = await self._objective_scorer.score_text_async(text=combined_value, objective=objective)
with with_execution_context(
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier(),
objective=objective,
):
scores = await self._objective_scorer.score_text_async(text=combined_value, objective=objective)
return scores[0] if scores else None

async def _teardown_async(self, *, context: ChunkedRequestAttackContext) -> None:
Expand Down
Loading