From 9bb872d5114c576866b41c11742bb592f0d4fc1a Mon Sep 17 00:00:00 2001
From: Alexander Alderman Webb
Date: Mon, 2 Feb 2026 10:23:41 +0100
Subject: [PATCH] fix(google-genai): Token reporting

---
 .../integrations/google_genai/streaming.py    | 57 ++++++++++---------
 .../google_genai/test_google_genai.py         | 23 +++-----
 2 files changed, 38 insertions(+), 42 deletions(-)

diff --git a/sentry_sdk/integrations/google_genai/streaming.py b/sentry_sdk/integrations/google_genai/streaming.py
index 5bd8890d02..8649ce2ac0 100644
--- a/sentry_sdk/integrations/google_genai/streaming.py
+++ b/sentry_sdk/integrations/google_genai/streaming.py
@@ -1,10 +1,4 @@
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    List,
-    TypedDict,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, List, TypedDict, Optional, Union
 
 from sentry_sdk.ai.utils import set_data_normalized
 from sentry_sdk.consts import SPANDATA
@@ -31,7 +25,21 @@ class AccumulatedResponse(TypedDict):
     text: str
     finish_reasons: "List[str]"
     tool_calls: "List[dict[str, Any]]"
-    usage_metadata: "UsageData"
+    usage_metadata: "Optional[UsageData]"
+
+
+def element_wise_usage_max(self: "UsageData", other: "UsageData") -> "UsageData":
+    return UsageData(
+        input_tokens=max(self["input_tokens"], other["input_tokens"]),
+        output_tokens=max(self["output_tokens"], other["output_tokens"]),
+        input_tokens_cached=max(
+            self["input_tokens_cached"], other["input_tokens_cached"]
+        ),
+        output_tokens_reasoning=max(
+            self["output_tokens_reasoning"], other["output_tokens_reasoning"]
+        ),
+        total_tokens=max(self["total_tokens"], other["total_tokens"]),
+    )
 
 
 def accumulate_streaming_response(
@@ -41,11 +49,7 @@ def accumulate_streaming_response(
     accumulated_text = []
     finish_reasons = []
     tool_calls = []
-    total_input_tokens = 0
-    total_output_tokens = 0
-    total_tokens = 0
-    total_cached_tokens = 0
-    total_reasoning_tokens = 0
+    usage_data = None
     response_id = None
     model = None
 
@@ -68,25 +72,21 @@ def accumulate_streaming_response(
         if extracted_tool_calls:
             tool_calls.extend(extracted_tool_calls)
 
-        # Accumulate token usage
-        extracted_usage_data = extract_usage_data(chunk)
-        total_input_tokens += extracted_usage_data["input_tokens"]
-        total_output_tokens += extracted_usage_data["output_tokens"]
-        total_cached_tokens += extracted_usage_data["input_tokens_cached"]
-        total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
-        total_tokens += extracted_usage_data["total_tokens"]
+        # Prefer the most recent chunk's usage so an interrupted stream
+        # still reports tokens, and take the element-wise maximum with
+        # earlier values to tolerate missing intermediate counts.
+        chunk_usage_data = extract_usage_data(chunk)
+        usage_data = (
+            chunk_usage_data
+            if usage_data is None
+            else element_wise_usage_max(usage_data, chunk_usage_data)
+        )
 
     accumulated_response = AccumulatedResponse(
         text="".join(accumulated_text),
         finish_reasons=finish_reasons,
         tool_calls=tool_calls,
-        usage_metadata=UsageData(
-            input_tokens=total_input_tokens,
-            output_tokens=total_output_tokens,
-            input_tokens_cached=total_cached_tokens,
-            output_tokens_reasoning=total_reasoning_tokens,
-            total_tokens=total_tokens,
-        ),
+        usage_metadata=usage_data,
         id=response_id,
         model=model,
     )
@@ -126,6 +126,9 @@ def set_span_data_for_streaming_response(
     if accumulated_response.get("model"):
         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
 
+    if accumulated_response["usage_metadata"] is None:
+        return
+
     if accumulated_response["usage_metadata"]["input_tokens"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
diff --git a/tests/integrations/google_genai/test_google_genai.py b/tests/integrations/google_genai/test_google_genai.py
index 37ba50420f..7510115bfa 100644
--- a/tests/integrations/google_genai/test_google_genai.py
+++ b/tests/integrations/google_genai/test_google_genai.py
@@ -452,13 +452,13 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
         "usageMetadata": {
             "promptTokenCount": 10,
             "candidatesTokenCount": 2,
-            "totalTokenCount": 12,  # Not set in intermediate chunks
+            "totalTokenCount": 12,
         },
         "responseId": "response-id-stream-123",
         "modelVersion": "gemini-1.5-flash",
     }
 
-    # Chunk 2: Second part of text with more usage metadata
+    # Chunk 2: Second part of text with intermediate usage metadata
     chunk2_json = {
         "candidates": [
             {
@@ -545,25 +545,18 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
     assert chat_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == "STOP"
     assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == "STOP"
 
-    # Verify token counts - should reflect accumulated values
-    # Input tokens: max of all chunks = 10
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
 
-    # Output tokens: candidates (2 + 3 + 7 = 12) + reasoning (3) = 15
-    # Note: output_tokens includes both candidates and reasoning tokens
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
 
-    # Total tokens: from the last chunk
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 25
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 25
 
-    # Cached tokens: max of all chunks = 5
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
     assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
 
-    # Reasoning tokens: sum of thoughts_token_count = 3
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3
     assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3
 
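
Reviewer note: below is a minimal, self-contained sketch of the accumulation
rule this patch introduces, runnable outside the SDK. Plain dicts stand in
for the UsageData TypedDict, and the per-chunk numbers are hypothetical,
chosen to mirror the updated test's expectations (10 input, 10 output,
25 total, 5 cached, 3 reasoning).

    # Sketch only: assumes Gemini chunks report *cumulative* usage counts
    # and that a chunk may omit a field (extracted as 0). The element-wise
    # running maximum is then robust both to missing intermediate counts
    # and to the stream being cut off before the final chunk arrives.

    def element_wise_usage_max(a, b):
        # Per-field maximum of two usage dicts with identical keys.
        return {key: max(a[key], b[key]) for key in a}

    chunks = [
        {"input_tokens": 10, "output_tokens": 2, "input_tokens_cached": 0,
         "output_tokens_reasoning": 0, "total_tokens": 12},
        # Intermediate chunk that omits input/total counts (reported as 0).
        {"input_tokens": 0, "output_tokens": 5, "input_tokens_cached": 5,
         "output_tokens_reasoning": 3, "total_tokens": 0},
        # Final chunk carries the authoritative cumulative totals.
        {"input_tokens": 10, "output_tokens": 10, "input_tokens_cached": 5,
         "output_tokens_reasoning": 3, "total_tokens": 25},
    ]

    usage = None
    for chunk_usage in chunks:
        usage = (
            chunk_usage
            if usage is None
            else element_wise_usage_max(usage, chunk_usage)
        )

    # The final cumulative report wins; zeroed fields never regress it.
    assert usage == chunks[-1]

Under these assumptions the old summing behavior would have double-counted
cumulative fields (e.g. input tokens as 10 + 0 + 10 = 20 here), which is
the inflation the removed test expectations (30/15/50) encoded.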