132 changes: 75 additions & 57 deletions sentry_sdk/integrations/anthropic.py
@@ -50,6 +50,13 @@
from sentry_sdk._types import TextPart


class _RecordedUsage:
output_tokens: int = 0
input_tokens: int = 0
cache_write_input_tokens: "Optional[int]" = 0
cache_read_input_tokens: "Optional[int]" = 0


class AnthropicIntegration(Integration):
identifier = "anthropic"
origin = f"auto.ai.{identifier}"
@@ -112,31 +119,15 @@ def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
def _collect_ai_data(
event: "MessageStreamEvent",
model: "str | None",
input_tokens: int,
output_tokens: int,
cache_read_input_tokens: int,
cache_write_input_tokens: int,
usage: "_RecordedUsage",
content_blocks: "list[str]",
) -> "tuple[str | None, int, int, int, int, list[str]]":
) -> "tuple[str | None, _RecordedUsage, list[str]]":
"""
Collect model information, token usage, and content blocks from the AI streaming response.
"""
with capture_internal_exceptions():
if hasattr(event, "type"):
if event.type == "message_start":
usage = event.message.usage
input_tokens += usage.input_tokens
output_tokens += usage.output_tokens
if hasattr(usage, "cache_read_input_tokens") and isinstance(
usage.cache_read_input_tokens, int
):
cache_read_input_tokens += usage.cache_read_input_tokens
if hasattr(usage, "cache_creation_input_tokens") and isinstance(
usage.cache_creation_input_tokens, int
):
cache_write_input_tokens += usage.cache_creation_input_tokens
model = event.message.model or model
elif event.type == "content_block_start":
if event.type == "content_block_start":
pass
elif event.type == "content_block_delta":
if hasattr(event.delta, "text"):
@@ -145,15 +136,60 @@ def _collect_ai_data(
content_blocks.append(event.delta.partial_json)
elif event.type == "content_block_stop":
pass
elif event.type == "message_delta":
output_tokens += event.usage.output_tokens

# Token counting logic mirrors the anthropic SDK, which also extracts the already-accumulated token counts.
# https://github.com/anthropics/anthropic-sdk-python/blob/9c485f6966e10ae0ea9eabb3a921d2ea8145a25b/src/anthropic/lib/streaming/_messages.py#L433-L518
if event.type == "message_start":
model = event.message.model or model

incoming_usage = event.message.usage
usage.output_tokens = incoming_usage.output_tokens
usage.input_tokens = incoming_usage.input_tokens

usage.cache_write_input_tokens = getattr(
incoming_usage, "cache_creation_input_tokens", None
)
usage.cache_read_input_tokens = getattr(
incoming_usage, "cache_read_input_tokens", None
)

return (
model,
usage,
content_blocks,
)

# Counterintuitive, but message_delta contains cumulative token counts :)
if event.type == "message_delta":
usage.output_tokens = event.usage.output_tokens

# Update other usage fields if they exist in the event
input_tokens = getattr(event.usage, "input_tokens", None)
if input_tokens is not None:
usage.input_tokens = input_tokens

cache_creation_input_tokens = getattr(
event.usage, "cache_creation_input_tokens", None
)
if cache_creation_input_tokens is not None:
usage.cache_write_input_tokens = cache_creation_input_tokens

cache_read_input_tokens = getattr(
event.usage, "cache_read_input_tokens", None
)
if cache_read_input_tokens is not None:
usage.cache_read_input_tokens = cache_read_input_tokens
# TODO: Record event.usage.server_tool_use

return (
model,
usage,
content_blocks,
)

return (
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
usage,
content_blocks,
)
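
For reference, a minimal sketch of how these two event types drive _RecordedUsage. The events are stand-ins built with types.SimpleNamespace that carry only the attributes _collect_ai_data actually reads; the real objects are anthropic SDK models, and the model id and token numbers are illustrative, not taken from this diff:

    from types import SimpleNamespace

    usage = _RecordedUsage()
    model, content_blocks = None, []

    # message_start carries the initial usage snapshot and the model name.
    start = SimpleNamespace(
        type="message_start",
        message=SimpleNamespace(
            model="claude-sonnet-example",  # hypothetical model id
            usage=SimpleNamespace(
                input_tokens=10,
                output_tokens=1,
                cache_creation_input_tokens=0,
                cache_read_input_tokens=0,
            ),
        ),
    )
    model, usage, content_blocks = _collect_ai_data(start, model, usage, content_blocks)

    # message_delta reports *cumulative* output tokens, so the helper assigns
    # instead of adding; the last delta seen wins.
    delta = SimpleNamespace(type="message_delta", usage=SimpleNamespace(output_tokens=10))
    model, usage, content_blocks = _collect_ai_data(delta, model, usage, content_blocks)
    assert usage.input_tokens == 10 and usage.output_tokens == 10  # not 1 + 10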

@@ -414,27 +450,18 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A

def new_iterator() -> "Iterator[MessageStreamEvent]":
model = None
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
usage = _RecordedUsage()
content_blocks: "list[str]" = []

for event in old_iterator:
(
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
usage,
content_blocks,
) = _collect_ai_data(
event,
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
usage,
content_blocks,
)
yield event
@@ -443,37 +470,28 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
span=span,
integration=integration,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens,
cache_read_input_tokens=usage.cache_read_input_tokens,
cache_write_input_tokens=usage.cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)

async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
model = None
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
usage = _RecordedUsage()
content_blocks: "list[str]" = []

async for event in old_iterator:
(
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
usage,
content_blocks,
) = _collect_ai_data(
event,
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
usage,
content_blocks,
)
yield event
@@ -482,10 +500,10 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
span=span,
integration=integration,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens,
cache_read_input_tokens=usage.cache_read_input_tokens,
cache_write_input_tokens=usage.cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)
42 changes: 22 additions & 20 deletions tests/integrations/anthropic/test_anthropic.py
@@ -49,6 +49,7 @@ async def __call__(self, *args, **kwargs):
_set_output_data,
_collect_ai_data,
_transform_anthropic_content_block,
_RecordedUsage,
)
from sentry_sdk.ai.utils import transform_content_part, transform_message_content
from sentry_sdk.utils import package_version
@@ -307,8 +308,8 @@ def test_streaming_create_message(
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -412,8 +413,8 @@ async def test_streaming_create_message_async(
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -546,8 +547,8 @@ def test_streaming_create_message_with_input_json_delta(
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 51
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 417
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -688,8 +689,8 @@ async def test_streaming_create_message_with_input_json_delta_async(
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 51
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 417
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -849,18 +850,19 @@ def test_collect_ai_data_with_input_json_delta():
type="content_block_delta",
)
model = None
input_tokens = 10
output_tokens = 20

usage = _RecordedUsage()
usage.output_tokens = 20
usage.input_tokens = 10

content_blocks = []

model, new_input_tokens, new_output_tokens, _, _, new_content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, 0, 0, content_blocks
)
model, new_usage, new_content_blocks = _collect_ai_data(
event, model, usage, content_blocks
)
assert model is None
assert new_input_tokens == input_tokens
assert new_output_tokens == output_tokens
assert new_usage.input_tokens == usage.input_tokens
assert new_usage.output_tokens == usage.output_tokens
assert new_content_blocks == ["test"]
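
A complementary case worth pinning down is message_delta overwriting rather than adding. A sketch of such a test (hypothetical, not part of this diff), using a types.SimpleNamespace stand-in since _collect_ai_data only duck-types the event:

    from types import SimpleNamespace


    def test_collect_ai_data_with_message_delta_overwrites():
        # message_delta usage is cumulative, so prior counts must be replaced.
        event = SimpleNamespace(
            type="message_delta", usage=SimpleNamespace(output_tokens=25)
        )

        usage = _RecordedUsage()
        usage.input_tokens = 10
        usage.output_tokens = 20

        model, new_usage, new_content_blocks = _collect_ai_data(event, None, usage, [])

        assert model is None
        assert new_usage.output_tokens == 25  # overwritten, not 20 + 25
        assert new_usage.input_tokens == 10  # untouched when the event omits it
        assert new_content_blocks == []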


@@ -1345,8 +1347,8 @@ def test_streaming_create_message_with_system_prompt(
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -1465,8 +1467,8 @@ async def test_streaming_create_message_with_system_prompt_async(
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True

