Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ def _part_has_payload(part: types.Part) -> bool:
return True
if part.file_data and (part.file_data.file_uri or part.file_data.data):
return True
if part.function_response:
return True
return False


Expand Down
41 changes: 41 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,47 @@ async def test_generate_content_async_adds_fallback_user_message(
)


@pytest.mark.asyncio
async def test_generate_content_async_no_fallback_for_function_response(
mock_acompletion, lite_llm_instance
):
"""Tests that no fallback message is added for a user message with a function response."""
llm_request = LlmRequest(
contents=[
types.Content(
role="user",
parts=[
types.Part.from_function_response(
name="test_function",
response={"result": "test_result"},
)
],
)
]
)

# Run generate_content_async which calls _append_fallback_user_content_if_missing
async for _ in lite_llm_instance.generate_content_async(llm_request):
pass

# Verify that the fallback message was NOT added to the llm_request
assert len(llm_request.contents) == 1
assert len(llm_request.contents[0].parts) == 1
assert llm_request.contents[0].parts[0].function_response is not None

# Verify that the message sent to litellm does not contain the fallback text
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
user_messages = [
message for message in kwargs["messages"] if message["role"] == "user"
]
assert not any(
message.get("content")
== "Handle the requests as specified in the System Instruction."
for message in user_messages
)


litellm_append_user_content_test_cases = [
pytest.param(
LlmRequest(
Expand Down
Loading