57 changes: 30 additions & 27 deletions sentry_sdk/integrations/google_genai/streaming.py
@@ -1,10 +1,4 @@
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    List,
-    TypedDict,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, List, TypedDict, Optional, Union

 from sentry_sdk.ai.utils import set_data_normalized
 from sentry_sdk.consts import SPANDATA
@@ -31,7 +25,21 @@ class AccumulatedResponse(TypedDict):
     text: str
     finish_reasons: "List[str]"
     tool_calls: "List[dict[str, Any]]"
-    usage_metadata: "UsageData"
+    usage_metadata: "Optional[UsageData]"


+def element_wise_usage_max(self: "UsageData", other: "UsageData") -> "UsageData":
+    return UsageData(
+        input_tokens=max(self["input_tokens"], other["input_tokens"]),
+        output_tokens=max(self["output_tokens"], other["output_tokens"]),
+        input_tokens_cached=max(
+            self["input_tokens_cached"], other["input_tokens_cached"]
+        ),
+        output_tokens_reasoning=max(
+            self["output_tokens_reasoning"], other["output_tokens_reasoning"]
+        ),
+        total_tokens=max(self["total_tokens"], other["total_tokens"]),
+    )
+
+
 def accumulate_streaming_response(
@@ -41,11 +49,7 @@ def accumulate_streaming_response(
     accumulated_text = []
     finish_reasons = []
     tool_calls = []
-    total_input_tokens = 0
-    total_output_tokens = 0
-    total_tokens = 0
-    total_cached_tokens = 0
-    total_reasoning_tokens = 0
+    usage_data = None
     response_id = None
     model = None

@@ -68,25 +72,21 @@ def accumulate_streaming_response(
         if extracted_tool_calls:
             tool_calls.extend(extracted_tool_calls)

-        # Accumulate token usage
-        extracted_usage_data = extract_usage_data(chunk)
-        total_input_tokens += extracted_usage_data["input_tokens"]
-        total_output_tokens += extracted_usage_data["output_tokens"]
-        total_cached_tokens += extracted_usage_data["input_tokens_cached"]
-        total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
-        total_tokens += extracted_usage_data["total_tokens"]
+        # Use the last possible chunk in case of interruption, and gracefully
+        # handle missing intermediate tokens by taking the maximum with the
+        # previously reported usage.
+        chunk_usage_data = extract_usage_data(chunk)
+        usage_data = (
+            chunk_usage_data
+            if usage_data is None
+            else element_wise_usage_max(usage_data, chunk_usage_data)
+        )

     accumulated_response = AccumulatedResponse(
         text="".join(accumulated_text),
         finish_reasons=finish_reasons,
         tool_calls=tool_calls,
-        usage_metadata=UsageData(
-            input_tokens=total_input_tokens,
-            output_tokens=total_output_tokens,
-            input_tokens_cached=total_cached_tokens,
-            output_tokens_reasoning=total_reasoning_tokens,
-            total_tokens=total_tokens,
-        ),
+        usage_metadata=usage_data,
         id=response_id,
         model=model,
     )
@@ -126,6 +126,9 @@ def set_span_data_for_streaming_response(
     if accumulated_response.get("model"):
         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])

+    if accumulated_response["usage_metadata"] is None:
+        return
+
     if accumulated_response["usage_metadata"]["input_tokens"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
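The core change above swaps sum-based token accumulation for an element-wise maximum. The premise, per the inline comment, is that streaming chunks repeat or cumulatively update usage counts rather than reporting per-chunk deltas, so summing over-counts, while a running per-field maximum converges on the final totals and tolerates chunks whose usage metadata is missing or lagging. Below is a minimal self-contained sketch of that strategy; the chunk values are illustrative, not taken from the SDK or its tests:

```python
from typing import Optional, TypedDict


class UsageData(TypedDict):
    input_tokens: int
    output_tokens: int
    input_tokens_cached: int
    output_tokens_reasoning: int
    total_tokens: int


def element_wise_usage_max(a: UsageData, b: UsageData) -> UsageData:
    # Field-wise maximum, mirroring the helper added in this PR.
    return UsageData(
        input_tokens=max(a["input_tokens"], b["input_tokens"]),
        output_tokens=max(a["output_tokens"], b["output_tokens"]),
        input_tokens_cached=max(a["input_tokens_cached"], b["input_tokens_cached"]),
        output_tokens_reasoning=max(
            a["output_tokens_reasoning"], b["output_tokens_reasoning"]
        ),
        total_tokens=max(a["total_tokens"], b["total_tokens"]),
    )


# Illustrative cumulative reports: the prompt count repeats in every chunk,
# and cached/total counts are temporarily missing (0) in the middle chunk.
chunks = [
    UsageData(input_tokens=10, output_tokens=2, input_tokens_cached=5,
              output_tokens_reasoning=0, total_tokens=12),
    UsageData(input_tokens=10, output_tokens=5, input_tokens_cached=0,
              output_tokens_reasoning=3, total_tokens=0),
    UsageData(input_tokens=10, output_tokens=10, input_tokens_cached=5,
              output_tokens_reasoning=3, total_tokens=25),
]

usage: Optional[UsageData] = None
for chunk_usage in chunks:
    usage = chunk_usage if usage is None else element_wise_usage_max(usage, chunk_usage)

assert usage == {
    "input_tokens": 10,           # summing would have over-counted to 30
    "output_tokens": 10,
    "input_tokens_cached": 5,     # survives the chunk where it was missing
    "output_tokens_reasoning": 3,
    "total_tokens": 25,           # summing would have reported 37
}
```

Because the last chunk carries the most complete counts, the maximum also degrades gracefully on interruption: the span reports the largest usage seen before the stream stopped.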
23 changes: 8 additions & 15 deletions tests/integrations/google_genai/test_google_genai.py
@@ -452,13 +452,13 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
         "usageMetadata": {
             "promptTokenCount": 10,
             "candidatesTokenCount": 2,
-            "totalTokenCount": 12,  # Not set in intermediate chunks
+            "totalTokenCount": 12,
         },
         "responseId": "response-id-stream-123",
         "modelVersion": "gemini-1.5-flash",
     }

-    # Chunk 2: Second part of text with more usage metadata
+    # Chunk 2: Second part of text with intermediate usage metadata
     chunk2_json = {
         "candidates": [
             {
@@ -545,25 +545,18 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
     assert chat_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == "STOP"
     assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == "STOP"

-    # Verify token counts - should reflect accumulated values
-    # Input tokens: max of all chunks = 10
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10

-    # Output tokens: candidates (2 + 3 + 7 = 12) + reasoning (3) = 15
-    # Note: output_tokens includes both candidates and reasoning tokens
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10

-    # Total tokens: from the last chunk
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 25
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 25

-    # Cached tokens: max of all chunks = 5
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
     assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5

-    # Reasoning tokens: sum of thoughts_token_count = 3
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3
     assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3

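A way to sanity-check the revised input-token expectation: the prompt count is repeated in each chunk's usage metadata, so the old sum-based accumulation scaled it by the number of chunks. A small illustrative check, assuming a three-chunk fixture, which is consistent with the old 30 == 3 × 10 expectation (the chunk3 fixture itself is not shown in this diff):

```python
# promptTokenCount reports the same full prompt in every chunk, not a delta.
prompt_token_counts = [10, 10, 10]  # assumed three-chunk fixture

assert sum(prompt_token_counts) == 30  # old expectation: inflated by repetition
assert max(prompt_token_counts) == 10  # new expectation: the actual prompt size
```

The assertions for cached (5) and reasoning (3) tokens are unchanged in the diff, since the old and new accumulation strategies happen to agree on those fields for this fixture.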