From d8576833ca4f3ed551a0c0a1bb2dd82992c3eeec Mon Sep 17 00:00:00 2001 From: Faraaz1994 Date: Wed, 28 Jan 2026 14:52:28 +0530 Subject: [PATCH] feat(http_options): add dynamic HTTP options support via RunConfig Enable per-request HTTP configuration (headers, timeout, retry_options) to be passed via RunConfig and propagated through the request pipeline to models. --- src/google/adk/agents/run_config.py | 5 +- .../adk/flows/llm_flows/base_llm_flow.py | 17 +++- src/google/adk/flows/llm_flows/basic.py | 34 ++++++- src/google/adk/models/lite_llm.py | 59 +++-------- .../flows/llm_flows/test_basic_processor.py | 47 +++++++++ tests/unittests/models/test_litellm.py | 97 +++++++++++++++++++ 6 files changed, 206 insertions(+), 53 deletions(-) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index a6f22c0bf6..343d8fb459 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -30,7 +30,6 @@ logger = logging.getLogger('google_adk.' + __name__) - class StreamingMode(Enum): """Streaming modes for agent execution. @@ -160,7 +159,6 @@ class StreamingMode(Enum): For bidirectional streaming, use runner.run_live() instead of run_async(). """ - class RunConfig(BaseModel): """Configs for runtime behavior of agents. @@ -175,6 +173,9 @@ class RunConfig(BaseModel): speech_config: Optional[types.SpeechConfig] = None """Speech configuration for the live agent.""" + http_options: Optional[types.HttpOptions] = None + """HTTP options for the agent execution (e.g. custom headers).""" + response_modalities: Optional[list[str]] = None """The output modalities. If not set, it's default to AUDIO.""" diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 759ac532fd..ff37a37592 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -67,7 +67,6 @@ # Statistics configuration DEFAULT_ENABLE_CACHE_STATISTICS = False - class BaseLlmFlow(ABC): """A basic flow that calls the LLM in a loop until a final response is generated. @@ -483,6 +482,22 @@ async def _preprocess_async( f'Expected agent to be an LlmAgent, but got {type(agent)}' ) + # Propagate http_options from RunConfig to LlmRequest as defaults. + # Request-level settings (from callbacks/processors) take precedence. + if ( + invocation_context.run_config + and invocation_context.run_config.http_options + ): + run_opts = invocation_context.run_config.http_options + if not llm_request.config.http_options: + # Deep-copy to avoid mutating the user's RunConfig across steps. + llm_request.config.http_options = run_opts.model_copy(deep=True) + elif run_opts.headers: + # Merge headers: request-level headers win (use setdefault). + if not llm_request.config.http_options.headers: + llm_request.config.http_options.headers = {} + for key, value in run_opts.headers.items(): + llm_request.config.http_options.headers.setdefault(key, value) # Runs processors. for processor in self.request_processors: async with Aclosing( diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index d97e535b23..db7ea39d0d 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -17,7 +17,6 @@ from __future__ import annotations from typing import AsyncGenerator -from typing import Generator from google.genai import types from typing_extensions import override @@ -28,7 +27,6 @@ from ...utils.output_schema_utils import can_use_output_schema_with_tools from ._base_llm_processor import BaseLlmRequestProcessor - class _BasicLlmRequestProcessor(BaseLlmRequestProcessor): @override @@ -38,11 +36,42 @@ async def run_async( agent = invocation_context.agent model = agent.canonical_model llm_request.model = model if isinstance(model, str) else model.model + + # Preserve http_options propagated from RunConfig + run_config_http_options = llm_request.config.http_options + llm_request.config = ( agent.generate_content_config.model_copy(deep=True) if agent.generate_content_config else types.GenerateContentConfig() ) + + if run_config_http_options: + # Merge RunConfig http_options back, overriding agent config + if not llm_request.config.http_options: + llm_request.config.http_options = run_config_http_options + else: + # Merge headers + if run_config_http_options.headers: + if not llm_request.config.http_options.headers: + llm_request.config.http_options.headers = {} + llm_request.config.http_options.headers.update( + run_config_http_options.headers + ) + + # Merge other http_options fields if present in RunConfig. + # RunConfig values override agent defaults. + # Note: base_url, api_version, base_url_resource_scope are intentionally + # excluded as they are configuration-time settings, not request-time. + for field in [ + 'timeout', + 'retry_options', + 'extra_body', + ]: + val = getattr(run_config_http_options, field, None) + if val is not None: + setattr(llm_request.config.http_options, field, val) + # Only set output_schema if no tools are specified. as of now, model don't # support output_schema and tools together. we have a workaround to support # both output_schema and tools at the same time. see @@ -84,5 +113,4 @@ async def run_async( return yield # Generator requires yield statement in function body. - request_processor = _BasicLlmRequestProcessor() diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 79182d7b0a..bdb2d9f355 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -148,7 +148,6 @@ "completion", ) - def _ensure_litellm_imported() -> None: """Imports LiteLLM with safe defaults. @@ -175,7 +174,6 @@ def _ensure_litellm_imported() -> None: _redirect_litellm_loggers_to_stdout() _LITELLM_IMPORTED = True - def _map_finish_reason( finish_reason: Any, ) -> types.FinishReason | None: @@ -187,7 +185,6 @@ def _map_finish_reason( finish_reason_str = str(finish_reason).lower() return _FINISH_REASON_MAPPING.get(finish_reason_str, types.FinishReason.OTHER) - def _get_provider_from_model(model: str) -> str: """Extracts the provider name from a LiteLLM model string. @@ -213,11 +210,9 @@ def _get_provider_from_model(model: str) -> str: return "openai" return "" - # Default MIME type when none can be inferred _DEFAULT_MIME_TYPE = "application/octet-stream" - def _infer_mime_type_from_uri(uri: str) -> Optional[str]: """Attempts to infer MIME type from a URI's path extension. @@ -258,12 +253,10 @@ def _infer_mime_type_from_uri(uri: str) -> Optional[str]: logger.debug("Could not infer MIME type from URI %s: %s", uri, e) return None - def _looks_like_openai_file_id(file_uri: str) -> bool: """Returns True when file_uri resembles an OpenAI/Azure file id.""" return file_uri.startswith("file-") - def _redact_file_uri_for_log( file_uri: str, *, display_name: str | None = None ) -> str: @@ -284,7 +277,6 @@ def _redact_file_uri_for_log( return f"{parsed.scheme}:///{tail}" return f"{parsed.scheme}://" - def _requires_file_uri_fallback( provider: str, model: str, file_uri: str ) -> bool: @@ -297,7 +289,6 @@ def _requires_file_uri_fallback( return True return False - def _decode_inline_text_data(raw_bytes: bytes) -> str: """Decodes inline file bytes that represent textual content.""" try: @@ -306,7 +297,6 @@ def _decode_inline_text_data(raw_bytes: bytes) -> str: logger.debug("Falling back to latin-1 decoding for inline file bytes.") return raw_bytes.decode("latin-1", errors="replace") - def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: """Yields textual fragments from provider specific reasoning payloads.""" if reasoning_value is None: @@ -344,7 +334,6 @@ def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: elif isinstance(reasoning_value, (int, float, bool)): yield str(reasoning_value) - def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: """Converts provider reasoning payloads into Gemini thought parts.""" return [ @@ -353,7 +342,6 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: if text ] - def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: """Fetches the reasoning payload from a LiteLLM message or dict.""" if message is None: @@ -364,35 +352,29 @@ def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: return message.get("reasoning_content") return None - class ChatCompletionFileUrlObject(TypedDict, total=False): file_data: str file_id: str format: str - class FunctionChunk(BaseModel): id: Optional[str] name: Optional[str] args: Optional[str] index: Optional[int] = 0 - class TextChunk(BaseModel): text: str - class ReasoningChunk(BaseModel): parts: List[types.Part] - class UsageMetadataChunk(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int cached_prompt_tokens: int = 0 - class LiteLLMClient: """Provides acompletion method (for better testability).""" @@ -444,7 +426,6 @@ def completion( **kwargs, ) - def _safe_json_serialize(obj) -> str: """Convert any Python object to a JSON-serializable type or string. @@ -461,7 +442,6 @@ def _safe_json_serialize(obj) -> str: except (TypeError, OverflowError): return str(obj) - def _part_has_payload(part: types.Part) -> bool: """Checks whether a Part contains usable payload for the model.""" if part.text: @@ -472,7 +452,6 @@ def _part_has_payload(part: types.Part) -> bool: return True return False - def _append_fallback_user_content_if_missing( llm_request: LlmRequest, ) -> None: @@ -508,7 +487,6 @@ def _append_fallback_user_content_if_missing( ) ) - def _extract_cached_prompt_tokens(usage: Any) -> int: """Extracts cached prompt tokens from LiteLLM usage. @@ -561,7 +539,6 @@ def _extract_cached_prompt_tokens(usage: Any) -> int: return 0 - async def _content_to_message_param( content: types.Content, *, @@ -680,7 +657,6 @@ async def _content_to_message_param( reasoning_content=reasoning_content or None, ) - def _ensure_tool_results(messages: List[Message]) -> List[Message]: """Insert placeholder tool messages for missing tool results. @@ -741,7 +717,6 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: return healed_messages - async def _get_content( parts: Iterable[types.Part], *, @@ -891,7 +866,6 @@ async def _get_content( return content_objects - def _is_ollama_chat_provider( model: Optional[str], custom_llm_provider: Optional[str] ) -> bool: @@ -905,7 +879,6 @@ def _is_ollama_chat_provider( return True return False - def _flatten_ollama_content( content: OpenAIMessageContent | str | None, ) -> str | None: @@ -946,7 +919,6 @@ def _flatten_ollama_content( except TypeError: return str(blocks) - def _normalize_ollama_chat_messages( messages: list[Message], *, @@ -992,7 +964,6 @@ def _normalize_ollama_chat_messages( return normalized_messages - def _build_tool_call_from_json_dict( candidate: Any, *, index: int ) -> Optional[ChatCompletionMessageToolCall]: @@ -1040,7 +1011,6 @@ def _build_tool_call_from_json_dict( return tool_call - def _parse_tool_calls_from_text( text_block: str, ) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]: @@ -1083,7 +1053,6 @@ def _parse_tool_calls_from_text( return tool_calls, remainder or None - def _split_message_content_and_tool_calls( message: Message, ) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]: @@ -1105,7 +1074,6 @@ def _split_message_content_and_tool_calls( return content, [] - def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: """Converts a types.Content role to a litellm role. @@ -1120,7 +1088,6 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: return "assistant" return "user" - TYPE_LABELS = { "STRING": "string", "NUMBER": "number", @@ -1130,7 +1097,6 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: "INTEGER": "integer", } - def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: """Recursively converts a schema object or dict to a pure-python dict. @@ -1174,7 +1140,6 @@ def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: return schema_dict - def _function_declaration_to_tool_param( function_declaration: types.FunctionDeclaration, ) -> dict: @@ -1227,7 +1192,6 @@ def _function_declaration_to_tool_param( return tool_params - def _model_response_to_chunk( response: ModelResponse, ) -> Generator[ @@ -1317,7 +1281,6 @@ def _model_response_to_chunk( cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]), ), None - def _model_response_to_generate_content_response( response: ModelResponse, ) -> LlmResponse: @@ -1370,7 +1333,6 @@ def _model_response_to_generate_content_response( ) return llm_response - def _message_to_generate_content_response( message: Message, *, @@ -1417,7 +1379,6 @@ def _message_to_generate_content_response( model_version=model_version, ) - def _to_litellm_response_format( response_schema: types.SchemaUnion, model: str, @@ -1496,7 +1457,6 @@ def _to_litellm_response_format( }, } - async def _get_completion_inputs( llm_request: LlmRequest, model: str, @@ -1591,7 +1551,6 @@ async def _get_completion_inputs( return messages, tools, response_format, generation_params - def _build_function_declaration_log( func_decl: types.FunctionDeclaration, ) -> str: @@ -1615,7 +1574,6 @@ def _build_function_declaration_log( return_str = str(func_decl.response.model_dump(exclude_none=True)) return f"{func_decl.name}: {param_str} -> {return_str}" - def _build_request_log(req: LlmRequest) -> str: """Builds a request log. @@ -1664,7 +1622,6 @@ def _build_request_log(req: LlmRequest) -> str: ----------------------------------------------------------- """ - def _is_litellm_gemini_model(model_string: str) -> bool: """Check if the model is a Gemini model accessed via LiteLLM. @@ -1677,7 +1634,6 @@ def _is_litellm_gemini_model(model_string: str) -> bool: """ return model_string.startswith(("gemini/gemini-", "vertex_ai/gemini-")) - def _extract_gemini_model_from_litellm(litellm_model: str) -> str: """Extract the pure Gemini model name from a LiteLLM model string. @@ -1692,7 +1648,6 @@ def _extract_gemini_model_from_litellm(litellm_model: str) -> str: return litellm_model.split("/", 1)[1] return litellm_model - def _warn_gemini_via_litellm(model_string: str) -> None: """Warn if Gemini is being used via LiteLLM. @@ -1723,7 +1678,6 @@ def _warn_gemini_via_litellm(model_string: str) -> None: stacklevel=3, ) - def _redirect_litellm_loggers_to_stdout() -> None: """Redirects LiteLLM loggers from stderr to stdout. @@ -1742,7 +1696,6 @@ def _redirect_litellm_loggers_to_stdout() -> None: ): handler.stream = sys.stdout - class LiteLlm(BaseLlm): """Wrapper around litellm. @@ -1836,6 +1789,18 @@ async def generate_content_async( if generation_params: completion_args.update(generation_params) + if ( + llm_request.config.http_options + and llm_request.config.http_options.headers + ): + extra_headers = completion_args.get("extra_headers", {}) + if isinstance(extra_headers, dict): + extra_headers = extra_headers.copy() + else: + extra_headers = {} + extra_headers.update(llm_request.config.http_options.headers) + completion_args["extra_headers"] = extra_headers + if stream: text = "" reasoning_parts: List[types.Part] = [] diff --git a/tests/unittests/flows/llm_flows/test_basic_processor.py b/tests/unittests/flows/llm_flows/test_basic_processor.py index af0ccfe0b1..60896b850b 100644 --- a/tests/unittests/flows/llm_flows/test_basic_processor.py +++ b/tests/unittests/flows/llm_flows/test_basic_processor.py @@ -188,3 +188,50 @@ async def test_sets_model_name(self): # Should have set the model name assert llm_request.model == 'gemini-1.5-flash' + + @pytest.mark.asyncio + async def test_preserves_merged_http_options(self): + """Test that processor preserves and merges existing http_options.""" + from google.genai import types + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + generate_content_config=types.GenerateContentConfig( + http_options=types.HttpOptions( + timeout=1000, + headers={'Agent-Header': 'agent-val'}, + ) + ) + ) + + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + + # Simulate http_options propagated from RunConfig + llm_request.config.http_options = types.HttpOptions( + timeout=500, # Should override agent + headers={ + 'RunConfig-Header': 'run-val', + 'Agent-Header': 'run-val-override' + } + ) + + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Verify timeout from RunConfig wins + assert llm_request.config.http_options.timeout == 500 + + # Verify headers merged, RunConfig wins + assert ( + llm_request.config.http_options.headers['RunConfig-Header'] == 'run-val' + ) + assert ( + llm_request.config.http_options.headers['Agent-Header'] + == 'run-val-override' + ) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 2ebbc5dfe8..10404e17e6 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -3793,3 +3793,100 @@ def test_handles_litellm_logger_names(logger_name): finally: # Clean up test_logger.removeHandler(handler) + + +@pytest.mark.asyncio +async def test_generate_content_async_passes_http_options_headers_as_extra_headers( + mock_acompletion, lite_llm_instance +): + """Test that http_options.headers from LlmRequest are forwarded to litellm.""" + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions( + headers={"X-User-Id": "user-123", "X-Trace-Id": "trace-abc"} + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "extra_headers" in kwargs + assert kwargs["extra_headers"]["X-User-Id"] == "user-123" + assert kwargs["extra_headers"]["X-Trace-Id"] == "trace-abc" + + +@pytest.mark.asyncio +async def test_generate_content_async_merges_http_options_with_existing_extra_headers( + mock_response, +): + """Test that http_options.headers merge with pre-existing extra_headers.""" + mock_acompletion = AsyncMock(return_value=mock_response) + mock_client = MockLLMClient(mock_acompletion, Mock()) + # Create instance with pre-existing extra_headers via kwargs + lite_llm_with_extra = LiteLlm( + model="test_model", + llm_client=mock_client, + extra_headers={"X-Api-Key": "secret-key"}, + ) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions(headers={"X-User-Id": "user-456"}) + ), + ) + + async for _ in lite_llm_with_extra.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "extra_headers" in kwargs + # Both existing and new headers should be present + assert kwargs["extra_headers"]["X-Api-Key"] == "secret-key" + assert kwargs["extra_headers"]["X-User-Id"] == "user-456" + + +@pytest.mark.asyncio +async def test_generate_content_async_http_options_headers_override_existing( + mock_response, +): + """Test that http_options.headers override same-key extra_headers from init.""" + mock_acompletion = AsyncMock(return_value=mock_response) + mock_client = MockLLMClient(mock_acompletion, Mock()) + lite_llm_with_extra = LiteLlm( + model="test_model", + llm_client=mock_client, + extra_headers={"X-Override-Me": "old-value"}, + ) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions(headers={"X-Override-Me": "new-value"}) + ), + ) + + async for _ in lite_llm_with_extra.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + # Request-level headers should override init-level headers + assert kwargs["extra_headers"]["X-Override-Me"] == "new-value"