Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion py/packages/genkit/src/genkit/ai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,21 @@ def _initialize_server(self, reflection_server_spec: ServerSpec | None) -> None:
def _join(self):
"""Block until Genkit internal threads are closed. Only blocking in dev mode."""
if is_dev_environment() and self._thread:
self._thread.join()
if self._loop:
if not self._loop.is_running():
logger.info('Starting main thread event loop to handle background tasks')
try:
self._loop.run_forever()
except KeyboardInterrupt:
logger.info('Main thread event loop interrupted')
finally:
if self._loop.is_running():
self._loop.stop()
else:
logger.warning('Event loop already running in _join, falling back to thread join')
self._thread.join()
else:
self._thread.join()

def _start_server(self, spec: ServerSpec, loop: asyncio.AbstractEventLoop) -> None:
"""Start the HTTP server for handling requests.
Expand Down
2 changes: 1 addition & 1 deletion py/packages/genkit/src/genkit/aio/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def run_fn() -> Any:
try:
output = await fn()
return output
except Exception as e:
except BaseException as e:
error = e
finally:
lock.release()
Expand Down
4 changes: 3 additions & 1 deletion py/packages/genkit/src/genkit/core/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _build_actions_payload(
actions[key] = {
'key': key,
'name': action.name,
'kind': action.kind.value, # Add 'kind' field for Dev UI
'type': action.kind.value,
'description': action.description,
'inputSchema': action.input_schema,
Expand All @@ -120,6 +121,7 @@ def _build_actions_payload(
advertised = {
'key': key,
'name': meta.name,
'kind': meta.kind.value, # Add 'kind' field for Dev UI
'type': meta.kind.value,
'description': getattr(meta, 'description', None),
'inputSchema': getattr(meta, 'input_json_schema', None),
Expand Down Expand Up @@ -694,7 +696,7 @@ def wrapped_on_trace_start(tid):
status_code=200,
headers={'x-genkit-version': version},
)
except Exception as e:
except BaseException as e:
error_response = get_reflection_json(e).model_dump(by_alias=True)
logger.error('Error executing action', error=error_response)
return JSONResponse(
Expand Down
18 changes: 18 additions & 0 deletions py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,24 @@ async def list_actions(self, allowed_kinds: list[ActionKind] | None = None) -> l
if allowed_kinds and meta.kind not in allowed_kinds:
continue
metas.append(meta)

# Include actions registered directly in the registry
with self._lock:
for kind, kind_map in self._entries.items():
if allowed_kinds and kind not in allowed_kinds:
continue
for action in kind_map.values():
metas.append(
ActionMetadata(
kind=action.kind,
name=action.name,
description=action.description,
input_json_schema=action.input_schema,
output_json_schema=action.output_schema,
metadata=action.metadata,
)
)

return metas

def register_schema(self, name: str, schema: dict[str, Any]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion py/packages/genkit/tests/genkit/core/registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ async def list_actions(self) -> list[ActionMetadata]:
ai = Genkit(plugins=[MyPlugin()])

metas = await ai.registry.list_actions()
assert metas == [ActionMetadata(kind=ActionKind.MODEL, name='myplugin/foo')]
# Filter for the specific plugin action we expect, ignoring system actions like 'generate'
target_meta = next((m for m in metas if m.name == 'myplugin/foo'), None)
assert target_meta == ActionMetadata(kind=ActionKind.MODEL, name='myplugin/foo')

action = await ai.registry.resolve_action(ActionKind.MODEL, 'myplugin/foo')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool:
params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={})

function = genai_types.FunctionDeclaration(
name=tool.name,
name=tool.name.replace('/', '__'),
description=tool.description,
parameters=params,
response=self._convert_schema_property(tool.output_schema) if tool.output_schema else None,
Expand Down Expand Up @@ -911,9 +911,18 @@ async def _generate(
),
)
client = client or self._client
response = await client.aio.models.generate_content(
model=model_name, contents=request_contents, config=request_cfg
)
try:
import structlog

logger = structlog.get_logger()
logger.debug('Gemini: calling generate_content', model=model_name)
response = await client.aio.models.generate_content(
model=model_name, contents=request_contents, config=request_cfg
)
logger.debug('Gemini: received response')
except Exception as e:
logger.error('Gemini: generate_content failed', error=str(e))
raise
span.set_attribute('genkit:output', dump_json(response))

content = self._contents_from_response(response)
Expand Down
2 changes: 1 addition & 1 deletion py/plugins/mcp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ build-backend = "hatchling.build"
requires = ["hatchling"]

[tool.hatch.build.targets.wheel]
packages = ["src"]
packages = ["src/genkit"]
139 changes: 65 additions & 74 deletions py/plugins/mcp/src/genkit/plugins/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
from contextlib import AsyncExitStack
from typing import Any

import structlog
from pydantic import BaseModel

from genkit.ai import Genkit, Plugin
from genkit.core.action import Action, ActionMetadata
from genkit.ai import Genkit
from genkit.ai._registry import GenkitRegistry
from genkit.core.action.types import ActionKind
from genkit.core.plugin import Plugin
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
Expand All @@ -38,57 +41,33 @@ class McpServerConfig(BaseModel):
disabled: bool = False


class McpClient(Plugin):
class McpClient:
"""Client for connecting to a single MCP server."""

def __init__(self, name: str, config: McpServerConfig, server_name: str | None = None):
self.name = name
self.config = config
self.server_name = server_name or name
self.session: ClientSession | None = None
self._exit_stack = None
self._session_context = None
self.ai: Genkit | None = None
self._exit_stack = AsyncExitStack()
self.ai: GenkitRegistry | None = None

def plugin_name(self) -> str:
return self.name

async def init(self) -> list[Action]:
"""Initialize MCP plugin.
def initialize(self, ai: GenkitRegistry) -> None:
self.ai = ai

MCP tools are registered dynamically upon connection, so this returns an empty list.
Returns:
Empty list (tools are registered dynamically).
"""
return []

async def resolve(self, action_type: ActionKind, name: str) -> Action | None:
"""Resolve an action by name.
MCP uses dynamic registration, so this returns None.
Args:
action_type: The kind of action to resolve.
name: The namespaced name of the action to resolve.
Returns:
None (MCP uses dynamic registration).
"""
return None

async def list_actions(self) -> list[ActionMetadata]:
"""List available MCP actions.
MCP tools are discovered at runtime, so this returns an empty list.
Returns:
Empty list (tools are discovered at runtime).
"""
return []
def resolve_action(self, ai: GenkitRegistry, kind: ActionKind, name: str) -> None:
# MCP tools are dynamic and currently registered upon connection/Discovery.
# This hook allows lazy resolution if we implement it.
pass

async def connect(self):
"""Connects to the MCP server."""
if self.session:
return

if self.config.disabled:
logger.info(f'MCP server {self.server_name} is disabled.')
return
Expand All @@ -100,23 +79,19 @@ async def connect(self):
)
# stdio_client returns (read, write) streams
stdio_context = stdio_client(server_params)
read, write = await stdio_context.__aenter__()
self._exit_stack = stdio_context
read, write = await self._exit_stack.enter_async_context(stdio_context)

# Create and initialize session
session_context = ClientSession(read, write)
self.session = await session_context.__aenter__()
self._session_context = session_context
self.session = await self._exit_stack.enter_async_context(session_context)

elif self.config.url:
# TODO: Verify SSE client usage in mcp python SDK
sse_context = sse_client(self.config.url)
read, write = await sse_context.__aenter__()
self._exit_stack = sse_context
read, write = await self._exit_stack.enter_async_context(sse_context)

session_context = ClientSession(read, write)
self.session = await session_context.__aenter__()
self._session_context = session_context
self.session = await self._exit_stack.enter_async_context(session_context)

await self.session.initialize()
logger.info(f'Connected to MCP server: {self.server_name}')
Expand All @@ -130,16 +105,16 @@ async def connect(self):

async def close(self):
"""Closes the connection."""
if hasattr(self, '_session_context') and self._session_context:
try:
await self._session_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f'Error closing session: {e}')
if self._exit_stack:
try:
await self._exit_stack.__aexit__(None, None, None)
except Exception as e:
logger.debug(f'Error closing transport: {e}')
await self._exit_stack.aclose()
except (Exception, asyncio.CancelledError):
# Ignore errors during cleanup, especially cancellation from anyio
pass

# Reset exit stack for potential reuse (reconnect)
self._exit_stack = AsyncExitStack()
self.session = None

async def list_tools(self) -> list[Tool]:
if not self.session:
Expand All @@ -150,14 +125,21 @@ async def list_tools(self) -> list[Tool]:
async def call_tool(self, tool_name: str, arguments: dict) -> Any:
if not self.session:
raise RuntimeError('MCP client is not connected')
result: CallToolResult = await self.session.call_tool(tool_name, arguments)
# Process result similarly to JS SDK
if result.isError:
raise RuntimeError(f'Tool execution failed: {result.content}')
logger.debug(f'MCP {self.server_name}: calling tool {tool_name}', arguments=arguments)
try:
result: CallToolResult = await self.session.call_tool(tool_name, arguments)
logger.debug(f'MCP {self.server_name}: tool {tool_name} returned')

# Simple text extraction for now
texts = [c.text for c in result.content if c.type == 'text']
return ''.join(texts)
# Process result similarly to JS SDK
if result.isError:
raise RuntimeError(f'Tool execution failed: {result.content}')

# Simple text extraction for now
texts = [c.text for c in result.content if c.type == 'text']
return {'content': ''.join(texts)}
except Exception as e:
logger.error(f'MCP {self.server_name}: tool {tool_name} failed', error=str(e))
raise

async def list_prompts(self) -> list[Prompt]:
if not self.session:
Expand Down Expand Up @@ -194,29 +176,38 @@ async def register_tools(self, ai: Genkit | None = None):
try:
tools = await self.list_tools()
for tool in tools:
# Create a wrapper function for the tool
# We need to capture tool and client in closure
async def tool_wrapper(args: Any = None, _tool_name=tool.name):
# args might be Pydantic model or dict. Genkit passes dict usually?
# TODO: Validate args against schema if needed
arguments = args
if hasattr(args, 'model_dump'):
arguments = args.model_dump()
return await self.call_tool(_tool_name, arguments or {})
# Create a wrapper function for the tool using a factory to capture tool name
def create_wrapper(tool_name: str):
async def tool_wrapper(args: Any = None):
# args might be Pydantic model or dict. Genkit passes dict usually?
# TODO: Validate args against schema if needed
arguments = args
if hasattr(args, 'model_dump'):
arguments = args.model_dump()
return await self.call_tool(tool_name, arguments or {})

return tool_wrapper

tool_wrapper = create_wrapper(tool.name)

# Use metadata to store MCP specific info
metadata = {'mcp': {'_meta': tool._meta}} if hasattr(tool, '_meta') else {}

# Define the tool in Genkit registry
registry.register_action(
action = registry.register_action(
kind=ActionKind.TOOL,
name=f'{self.server_name}/{tool.name}',
name=f'{self.server_name}_{tool.name}',
fn=tool_wrapper,
description=tool.description,
metadata=metadata,
# TODO: json_schema conversion from tool.inputSchema
)
logger.debug(f'Registered MCP tool: {self.server_name}/{tool.name}')

# Patch input schema from MCP tool definition
if tool.inputSchema:
action._input_schema = tool.inputSchema
action._metadata['inputSchema'] = tool.inputSchema

logger.debug(f'Registered MCP tool: {self.server_name}_{tool.name}')
except Exception as e:
logger.error(f'Error registering tools for {self.server_name}: {e}')

Expand Down
Loading
Loading