From b019504b7bcbef603c37528fe28a17fce9bf9ab0 Mon Sep 17 00:00:00 2001 From: Benjamin Kluger Date: Fri, 30 Jan 2026 20:24:18 -0500 Subject: [PATCH] feat: Add voice agent SDK module and CLI template Add comprehensive voice agent support to the Agentex SDK: ## New SDK Module (src/agentex/voice/) - VoiceAgentBase: Base class with state management, interruption handling, guardrail execution, and streaming support - AgentState/AgentResponse: Pydantic models for conversation state - Guardrail/LLMGuardrail: Abstract base classes for implementing guardrails ## New CLI Template (agentex init --voice) - Full project scaffolding for voice agents - Multi-provider LLM support (OpenAI, Azure, SGP, Vertex AI, Mock) - Docker + Kubernetes deployment configuration - Example guardrail structure ## CLI Changes - Added hidden --voice flag to 'agentex init' command - Generates voice-specific project structure ## Dependencies - Added partial-json-parser for streaming JSON parsing - Added google-auth for Vertex AI authentication --- pyproject.toml | 2 + src/agentex/lib/cli/commands/init.py | 85 +- .../lib/cli/templates/voice/.dockerignore.j2 | 43 + .../lib/cli/templates/voice/Dockerfile-uv.j2 | 42 + .../lib/cli/templates/voice/Dockerfile.j2 | 43 + .../lib/cli/templates/voice/README.md.j2 | 399 +++++++++ .../lib/cli/templates/voice/dev.ipynb.j2 | 167 ++++ .../cli/templates/voice/environments.yaml.j2 | 37 + .../lib/cli/templates/voice/manifest.yaml.j2 | 123 +++ .../lib/cli/templates/voice/project/acp.py.j2 | 360 ++++++++ .../lib/cli/templates/voice/pyproject.toml.j2 | 36 + .../cli/templates/voice/requirements.txt.j2 | 9 + .../lib/cli/templates/voice/test_agent.py.j2 | 70 ++ src/agentex/voice/__init__.py | 31 + src/agentex/voice/agent.py | 788 ++++++++++++++++++ src/agentex/voice/guardrails.py | 304 +++++++ src/agentex/voice/models.py | 53 ++ src/agentex/voice/py.typed | 0 uv.lock | 19 +- 19 files changed, 2576 insertions(+), 35 deletions(-) create mode 100644 src/agentex/lib/cli/templates/voice/.dockerignore.j2 create mode 100644 src/agentex/lib/cli/templates/voice/Dockerfile-uv.j2 create mode 100644 src/agentex/lib/cli/templates/voice/Dockerfile.j2 create mode 100644 src/agentex/lib/cli/templates/voice/README.md.j2 create mode 100644 src/agentex/lib/cli/templates/voice/dev.ipynb.j2 create mode 100644 src/agentex/lib/cli/templates/voice/environments.yaml.j2 create mode 100644 src/agentex/lib/cli/templates/voice/manifest.yaml.j2 create mode 100644 src/agentex/lib/cli/templates/voice/project/acp.py.j2 create mode 100644 src/agentex/lib/cli/templates/voice/pyproject.toml.j2 create mode 100644 src/agentex/lib/cli/templates/voice/requirements.txt.j2 create mode 100644 src/agentex/lib/cli/templates/voice/test_agent.py.j2 create mode 100644 src/agentex/voice/__init__.py create mode 100644 src/agentex/voice/agent.py create mode 100644 src/agentex/voice/guardrails.py create mode 100644 src/agentex/voice/models.py create mode 100644 src/agentex/voice/py.typed diff --git a/pyproject.toml b/pyproject.toml index bf518fb94..04fcfe4ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dependencies = [ "yaspin>=3.1.0", "claude-agent-sdk>=0.1.0", "anthropic>=0.40.0", + "partial-json-parser>=0.2.1", # For voice agent streaming JSON parsing + "google-auth>=2.0.0", # For Vertex AI authentication in voice agents ] requires-python = ">= 3.12,<4" diff --git a/src/agentex/lib/cli/commands/init.py b/src/agentex/lib/cli/commands/init.py index 1654f48c5..8c0485a96 100644 --- a/src/agentex/lib/cli/commands/init.py +++ 
b/src/agentex/lib/cli/commands/init.py @@ -1,9 +1,10 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict +from typing import Any, Dict, Optional from pathlib import Path +import typer import questionary from jinja2 import Environment, FileSystemLoader from rich.rule import Rule @@ -27,6 +28,7 @@ class TemplateType(str, Enum): DEFAULT = "default" SYNC = "sync" SYNC_OPENAI_AGENTS = "sync-openai-agents" + VOICE = "voice" def render_template( @@ -60,6 +62,7 @@ def create_project_structure( TemplateType.DEFAULT: ["acp.py"], TemplateType.SYNC: ["acp.py"], TemplateType.SYNC_OPENAI_AGENTS: ["acp.py"], + TemplateType.VOICE: ["acp.py"], }[template_type] # Create project/code files @@ -102,9 +105,15 @@ def get_project_context(answers: Dict[str, Any], project_path: Path, manifest_ro # Now, this is actually the exact same as the project_name because we changed the build root to be ../ project_path_from_build_root = project_name + # Create PascalCase class name from agent name + agent_class_name = "".join( + word.capitalize() for word in answers["agent_name"].split("-") + ) + return { **answers, "project_name": project_name, + "agent_class_name": agent_class_name, "workflow_class": "".join( word.capitalize() for word in answers["agent_name"].split("-") ) @@ -115,7 +124,14 @@ def get_project_context(answers: Dict[str, Any], project_path: Path, manifest_ro } -def init(): +def init( + voice: bool = typer.Option( + False, + "--voice", + hidden=True, + help="Create a voice agent template (LiveKit + Gemini)", + ), +): """Initialize a new agent project""" console.print( Panel.fit( @@ -124,25 +140,40 @@ def init(): ) ) - # Use a Rich table for template descriptions - table = Table(show_header=True, header_style="bold blue") - table.add_column("Template", style="cyan", no_wrap=True) - table.add_column("Description", style="white") - table.add_row( - "[bold cyan]Async - ACP Only[/bold cyan]", - "Asynchronous, non-blocking agent that can process multiple concurrent requests. Best for straightforward asynchronous agents that don't need durable execution. Good for asynchronous workflows, stateful applications, and multi-step analysis.", - ) - table.add_row( - "[bold cyan]Async - Temporal[/bold cyan]", - "Asynchronous, non-blocking agent with durable execution for all steps. Best for production-grade agents that require complex multi-step tool calls, human-in-the-loop approvals, and long-running processes that require transactional reliability.", - ) - table.add_row( - "[bold cyan]Sync ACP[/bold cyan]", - "Synchronous agent that processes one request per task with a simple request-response pattern. Best for low-latency use cases, FAQ bots, translation services, and data lookups.", - ) - console.print() - console.print(table) - console.print() + # If --voice flag is passed, skip the menu and use voice template + if voice: + console.print("[bold cyan]Creating Voice Agent template...[/bold cyan]\n") + template_type = TemplateType.VOICE + else: + # Use a Rich table for template descriptions + table = Table(show_header=True, header_style="bold blue") + table.add_column("Template", style="cyan", no_wrap=True) + table.add_column("Description", style="white") + table.add_row( + "[bold cyan]Async - ACP Only[/bold cyan]", + "Asynchronous, non-blocking agent that can process multiple concurrent requests. Best for straightforward asynchronous agents that don't need durable execution. 
Good for asynchronous workflows, stateful applications, and multi-step analysis.", + ) + table.add_row( + "[bold cyan]Async - Temporal[/bold cyan]", + "Asynchronous, non-blocking agent with durable execution for all steps. Best for production-grade agents that require complex multi-step tool calls, human-in-the-loop approvals, and long-running processes that require transactional reliability.", + ) + table.add_row( + "[bold cyan]Sync ACP[/bold cyan]", + "Synchronous agent that processes one request per task with a simple request-response pattern. Best for low-latency use cases, FAQ bots, translation services, and data lookups.", + ) + console.print() + console.print(table) + console.print() + + # Gather project information + template_type = questionary.select( + "What type of template would you like to create?", + choices=[ + {"name": "Async - ACP Only", "value": TemplateType.DEFAULT}, + {"name": "Async - Temporal", "value": "temporal_submenu"}, + {"name": "Sync ACP", "value": "sync_submenu"}, + ], + ).ask() def validate_agent_name(text: str) -> bool | str: """Validate agent name follows required format""" @@ -150,17 +181,7 @@ def validate_agent_name(text: str) -> bool | str: if not is_valid: return "Invalid name. Use only lowercase letters, numbers, and hyphens. Examples: 'my-agent', 'newsbot'" return True - - # Gather project information - template_type = questionary.select( - "What type of template would you like to create?", - choices=[ - {"name": "Async - ACP Only", "value": TemplateType.DEFAULT}, - {"name": "Async - Temporal", "value": "temporal_submenu"}, - {"name": "Sync ACP", "value": "sync_submenu"}, - ], - ).ask() - if not template_type: + if template_type is None: return # If Temporal was selected, show sub-menu for Temporal variants diff --git a/src/agentex/lib/cli/templates/voice/.dockerignore.j2 b/src/agentex/lib/cli/templates/voice/.dockerignore.j2 new file mode 100644 index 000000000..c2d7fca4d --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/.dockerignore.j2 @@ -0,0 +1,43 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Environments +.env** +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Git +.git +.gitignore + +# Misc +.DS_Store diff --git a/src/agentex/lib/cli/templates/voice/Dockerfile-uv.j2 b/src/agentex/lib/cli/templates/voice/Dockerfile-uv.j2 new file mode 100644 index 000000000..2ac5be7d2 --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/Dockerfile-uv.j2 @@ -0,0 +1,42 @@ +# syntax=docker/dockerfile:1.3 +FROM python:3.12-slim +COPY --from=ghcr.io/astral-sh/uv:0.6.4 /uv /uvx /bin/ + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + htop \ + vim \ + curl \ + tar \ + python3-dev \ + postgresql-client \ + build-essential \ + libpq-dev \ + gcc \ + cmake \ + netcat-openbsd \ + nodejs \ + npm \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/** + +RUN uv pip install --system --upgrade pip setuptools wheel + +ENV UV_HTTP_TIMEOUT=1000 + +# Copy just the pyproject.toml file to optimize caching +COPY {{ project_path_from_build_root }}/pyproject.toml /app/{{ project_path_from_build_root }}/pyproject.toml + +WORKDIR /app/{{ project_path_from_build_root }} + +# Install the required Python packages using uv +RUN uv pip install --system . 
+
+# Copy the project code
+COPY {{ project_path_from_build_root }}/project /app/{{ project_path_from_build_root }}/project
+
+# Set environment variables
+ENV PYTHONPATH=/app
+
+# Run the agent using uvicorn
+CMD ["uvicorn", "project.acp:acp", "--host", "0.0.0.0", "--port", "8000"]
\ No newline at end of file
diff --git a/src/agentex/lib/cli/templates/voice/Dockerfile.j2 b/src/agentex/lib/cli/templates/voice/Dockerfile.j2
new file mode 100644
index 000000000..4d9f41d45
--- /dev/null
+++ b/src/agentex/lib/cli/templates/voice/Dockerfile.j2
@@ -0,0 +1,43 @@
+# syntax=docker/dockerfile:1.3
+FROM python:3.12-slim
+COPY --from=ghcr.io/astral-sh/uv:0.6.4 /uv /uvx /bin/
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    htop \
+    vim \
+    curl \
+    tar \
+    python3-dev \
+    postgresql-client \
+    build-essential \
+    libpq-dev \
+    gcc \
+    cmake \
+    netcat-openbsd \
+    nodejs \
+    npm \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN uv pip install --system --upgrade pip setuptools wheel
+
+ENV UV_HTTP_TIMEOUT=1000
+
+# Copy just the requirements file to optimize caching
+COPY {{ project_path_from_build_root }}/requirements.txt /app/{{ project_path_from_build_root }}/requirements.txt
+
+WORKDIR /app/{{ project_path_from_build_root }}
+
+# Install the required Python packages
+RUN uv pip install --system -r requirements.txt
+
+# Copy the project code
+COPY {{ project_path_from_build_root }}/project /app/{{ project_path_from_build_root }}/project
+
+
+# Set environment variables
+ENV PYTHONPATH=/app
+
+# Run the agent using uvicorn
+CMD ["uvicorn", "project.acp:acp", "--host", "0.0.0.0", "--port", "8000"]
\ No newline at end of file
diff --git a/src/agentex/lib/cli/templates/voice/README.md.j2 b/src/agentex/lib/cli/templates/voice/README.md.j2
new file mode 100644
index 000000000..a4be1ef0d
--- /dev/null
+++ b/src/agentex/lib/cli/templates/voice/README.md.j2
@@ -0,0 +1,399 @@
+# {{ agent_name }} - Voice Agent Template
+
+This is a starter template for building **voice agents** with the AgentEx framework. It provides a pre-configured implementation using the `VoiceAgentBase` class with LiveKit voice integration, conversation state management, guardrails, and streaming support.
+
+## What's Included
+
+This template provides:
+- ✅ **VoiceAgentBase** - Production-ready base class with 90% of boilerplate handled
+- ✅ **State Management** - Automatic conversation history and state persistence
+- ✅ **Interruption Handling** - Voice-specific logic for handling user interruptions
+- ✅ **Streaming Support** - Real-time streaming with concurrent guardrail checks
+- ✅ **Guardrail System** - Extensible policy enforcement framework
+- ✅ **LiveKit Ready** - Pre-configured for voice infrastructure
+- ✅ **Multi-Provider LLM Support** - OpenAI, Azure OpenAI, SGP, or Vertex AI (plus a mock mode for testing)
+
+## What You'll Learn
+
+- **Voice Agents**: Building conversational AI for voice interactions
+- **Tasks**: Grouping mechanism for conversation threads/sessions
+- **Messages**: Communication objects within a task (text, data, tool requests/responses)
+- **Sync ACP**: Synchronous Agent Communication Protocol with immediate responses
+- **State Management**: Persisting conversation state across messages
+- **Guardrails**: Enforcing policies and safety boundaries
+
+## Quick Start
+
+### 1. Install Dependencies
+
+{% if use_uv %}
+```bash
+# Using uv (recommended)
+uv sync
+```
+{% else %}
+```bash
+# Using pip
+pip install -r requirements.txt
+```
+{% endif %}
+
+### 2. 
Configure Environment + +Create a `.env` file in the project directory. Choose ONE LLM provider: + +```bash +# ============================================================================ +# LLM CONFIGURATION (choose one) +# ============================================================================ + +# OPTION 1: OpenAI (direct) +# Works anywhere with internet access to api.openai.com +OPENAI_API_KEY=sk-your-openai-key +LLM_MODEL=gpt-4o-mini # optional, this is the default + +# OPTION 2: OpenAI via Proxy (for VPC/firewall environments) +# Use when cluster can't reach api.openai.com directly +OPENAI_API_KEY=your-key +OPENAI_BASE_URL=https://your-internal-proxy.company.com/v1 +LLM_MODEL=gpt-4o-mini + +# OPTION 3: Azure OpenAI +# For Azure-hosted OpenAI deployments +AZURE_OPENAI_API_KEY=your-azure-key +AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com +AZURE_OPENAI_DEPLOYMENT=your-deployment-name +LLM_MODEL=gpt-4o # optional + +# OPTION 4: Scale Groundplane (SGP) for Gemini +# For Scale internal use +SGP_API_KEY=your_sgp_api_key +SGP_ACCOUNT_ID=your_sgp_account_id +LLM_MODEL=gemini-2.0-flash # optional + +# OPTION 5: Google Cloud Vertex AI +# For GCP deployments with Gemini +GOOGLE_GENAI_USE_VERTEXAI=true +GOOGLE_CLOUD_PROJECT=your-gcp-project-id +GOOGLE_CLOUD_LOCATION=us-central1 # optional, this is default +GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json +LLM_MODEL=google/gemini-2.0-flash # optional + +# OPTION 6: Mock Mode (for testing without LLM) +# Returns canned responses for integration testing +MOCK_MODE=true + +# ============================================================================ +# OPTIONAL: LiveKit (for voice infrastructure) +# ============================================================================ +# LIVEKIT_URL=wss://your-livekit-instance.com +# LIVEKIT_API_KEY=your_livekit_api_key +# LIVEKIT_API_SECRET=your_livekit_api_secret +``` + +**Note**: The manifest is pre-configured with `AGENTEX_BASE_URL: ""` for local testing. + +### Deployment Environments + +| Environment | Recommended Config | Notes | +|-------------|-------------------|-------| +| Local dev | OpenAI direct or `MOCK_MODE=true` | Fastest setup | +| Cloud (direct internet) | OpenAI or Azure | Works out of box | +| VPC/Firewall | `OPENAI_BASE_URL` with proxy | Set proxy URL | +| GCP/Vertex AI | `GOOGLE_GENAI_USE_VERTEXAI=true` | For GCP/Gemini | +| Scale internal | SGP credentials | For Gemini models | + +### 3. Run the Agent + +```bash +# Run the agent locally +uv run agentex agents run --manifest manifest.yaml + +# Or with pip-installed agentex: +agentex agents run --manifest manifest.yaml +``` + +### 4. Test Your Agent + +**Option A: Web UI (Recommended)** +```bash +# In the agentex-web directory +make dev + +# Open http://localhost:3000 and select your agent +``` + +**Option B: Development Notebook** +```bash +# Open the included Jupyter notebook +jupyter notebook dev.ipynb +``` + +## Customizing Your Agent + +### 1. Edit the System Prompt + +Modify `get_system_prompt()` in `project/acp.py`: + +```python +def get_system_prompt(self, conversation_state, guardrail_override=None): + return """ + +You are a [describe your agent's role] + + + +- Keep responses under 3 sentences for voice +- Use natural, conversational language +- Speak at a moderate pace +- Be empathetic and warm + + + +- [What can your agent do?] + +""" +``` + +### 2. 
Add Custom State Fields + +Extend the state model to track agent-specific data: + +```python +class MyAgentState(AgentState): + conversation_phase: str = "introduction" + user_name: Optional[str] = None + key_information_gathered: dict = Field(default_factory=dict) +``` + +### 3. Add Custom Response Fields + +Extend the response model for structured outputs: + +```python +class MyAgentResponse(AgentResponse): + phase_transition: bool = False + new_phase: Optional[str] = None + sentiment: Optional[str] = None +``` + +### 4. Add Tools + +Use the `@function_tool` decorator: + +```python +from agents import function_tool + +@function_tool +async def get_user_info(user_id: str) -> dict: + """Fetch user information from database.""" + # Your implementation + return {"name": "John", "account_id": "12345"} + +# Add to TOOLS list +TOOLS = [get_user_info] +``` + +### 5. Add Guardrails + +Create custom guardrails for policy enforcement: + +```python +from agentex.voice.guardrails import Guardrail + +class MedicalAdviceGuardrail(Guardrail): + def __init__(self): + super().__init__( + name="medical_advice", + outcome_prompt="I cannot provide medical advice. Please consult your healthcare provider." + ) + + async def check(self, user_message, conversation_state): + # Your logic here + medical_keywords = ["diagnose", "treatment", "prescription"] + return not any(kw in user_message.lower() for kw in medical_keywords) + +# Add to GUARDRAILS list +GUARDRAILS = [MedicalAdviceGuardrail()] +``` + +## Voice-Specific Best Practices + +### 1. Keep Responses Concise +Voice responses should be shorter than text: +- ❌ Bad: "I'd be delighted to help you with that question! Let me provide you with a comprehensive answer..." +- ✅ Good: "Sure! Here's what you need to know..." + +### 2. Use Natural Language +Speak like a human, not a robot: +- ❌ Bad: "Please provide the requested information in order to proceed with processing." +- ✅ Good: "What would you like me to help with?" + +### 3. Handle Interruptions Gracefully +The `VoiceAgentBase` automatically handles interruptions, but you can customize behavior in `handle_message_interruption()`. + +### 4. Consider TTS Pacing +- Use punctuation to control pacing (commas, periods, em-dashes) +- Break long responses into shorter chunks +- Avoid complex sentences with nested clauses + +### 5. 
Add Empathy Guidelines +Voice needs more empathy cues than text: +```python + +- Acknowledge emotions: "That sounds frustrating" or "I understand" +- Use warm, supportive tone +- Pause appropriately (use punctuation) +- Mirror user's energy level + +``` + +## Architecture + +### How VoiceAgentBase Works + +``` +User Message + ↓ +VoiceAgentBase.send_message() + ↓ +[State Management] + ↓ +[Interruption Handling] ← Checks for concurrent messages + ↓ +[Guardrail Checks] ← Concurrent with streaming + ↓ +[LLM Streaming] ← Your custom prompt + ↓ +[State Update] ← Your custom logic + ↓ +[Save State] + ↓ +Response to User +``` + +### What You Implement vs What's Handled + +**You Implement (2 methods):** +- `get_system_prompt()` - Your agent's behavior +- `update_state_and_tracing_from_response()` - State updates + +**VoiceAgentBase Handles:** +- ✅ State persistence and retrieval +- ✅ Conversation history management +- ✅ Message interruption detection +- ✅ Guardrail concurrent execution +- ✅ Streaming with buffering +- ✅ Error handling and tracing + +## Managing Dependencies + +{% if use_uv %} +### Using uv (Recommended) + +```bash +# Add new dependencies +uv add requests openai anthropic + +# Sync dependencies +uv sync + +# Run commands with uv +uv run agentex agents run --manifest manifest.yaml +``` +{% else %} +### Using pip + +```bash +# Add to requirements.txt +echo "requests" >> requirements.txt +echo "openai" >> requirements.txt + +# Install dependencies +pip install -r requirements.txt +``` +{% endif %} + +## Testing + +### Unit Tests + +```bash +pytest test_agent.py -v +``` + +### Manual Testing with Notebook + +The included `dev.ipynb` notebook provides: +- Non-streaming message tests +- Streaming message tests +- State inspection +- Task management examples + +## Deployment + +### Build the Docker Image + +```bash +agentex agents build --manifest manifest.yaml +``` + +### Deploy to Kubernetes + +```bash +# Deploy to dev environment +agentex agents deploy --manifest manifest.yaml --environment dev + +# Check deployment status +kubectl get pods -n team-{{ agent_name }} +``` + +## Troubleshooting + +### Common Issues + +**1. Agent not appearing in web UI** +- Check if agent is running on port 8000 +- Verify `ENVIRONMENT=development` is set +- Check agent logs for errors + +**2. Slow response times** +- Optimize your system prompt (shorter is faster) +- Consider caching expensive operations +- Use faster LLM model (e.g., gemini-flash vs gemini-pro) + +**3. Guardrails timing out** +- Check guardrail LLM calls are completing +- Use faster models for guardrails +- Reduce guardrail complexity + +**4. State not persisting** +- Verify state_id is not None +- Check agentex backend is running +- Look for errors in state.update() calls + +## Next Steps + +1. **Customize the prompt** - Make it specific to your use case +2. **Add tools** - Give your agent capabilities +3. **Add guardrails** - Enforce safety and policy boundaries +4. **Test thoroughly** - Use dev notebook and write tests +5. **Deploy** - Build and deploy to dev environment +6. 
**Iterate** - Monitor and improve based on user interactions + +## Resources + +- [Agentex Documentation](https://docs.agentex.ai) +- [Voice Agent Best Practices](https://docs.agentex.ai/voice-agents) +- [Guardrails Guide](https://docs.agentex.ai/guardrails) +- [Gemini Model Documentation](https://ai.google.dev/gemini-api/docs) + +## Support + +For questions or issues: +- Open an issue in the agentex repository +- Contact the Agentex team on Slack (#agentex-support) + +--- + +Happy building with Voice Agents! 🎙️🤖 diff --git a/src/agentex/lib/cli/templates/voice/dev.ipynb.j2 b/src/agentex/lib/cli/templates/voice/dev.ipynb.j2 new file mode 100644 index 000000000..d8c10a65a --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/dev.ipynb.j2 @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "36834357", + "metadata": {}, + "outputs": [], + "source": [ + "from agentex import Agentex\n", + "\n", + "client = Agentex(base_url=\"http://localhost:5003\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1c309d6", + "metadata": {}, + "outputs": [], + "source": [ + "AGENT_NAME = \"{{ agent_name }}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f6e6ef0", + "metadata": {}, + "outputs": [], + "source": [ + "# # (Optional) Create a new task. If you don't create a new task, each message will be sent to a new task. The server will create the task for you.\n", + "\n", + "# import uuid\n", + "\n", + "# TASK_ID = str(uuid.uuid4())[:8]\n", + "\n", + "# rpc_response = client.agents.rpc_by_name(\n", + "# agent_name=AGENT_NAME,\n", + "# method=\"task/create\",\n", + "# params={\n", + "# \"name\": f\"{TASK_ID}-task\",\n", + "# \"params\": {}\n", + "# }\n", + "# )\n", + "\n", + "# task = rpc_response.result\n", + "# print(task)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b03b0d37", + "metadata": {}, + "outputs": [], + "source": [ + "# Test non streaming response\n", + "from agentex.types import TextContent\n", + "\n", + "# The response is expected to be a list of TaskMessage objects, which is a union of the following types:\n", + "# - TextContent: A message with just text content \n", + "# - DataContent: A message with JSON-serializable data content\n", + "# - ToolRequestContent: A message with a tool request, which contains a JSON-serializable request to call a tool\n", + "# - ToolResponseContent: A message with a tool response, which contains response object from a tool call in its content\n", + "\n", + "# When processing the message/send response, if you are expecting more than TextContent, such as DataContent, ToolRequestContent, or ToolResponseContent, you can process them as well\n", + "\n", + "rpc_response = client.agents.send_message(\n", + " agent_name=AGENT_NAME,\n", + " params={\n", + " \"content\": {\"type\": \"text\", \"author\": \"user\", \"content\": \"Hello what can you do?\"},\n", + " \"stream\": False\n", + " }\n", + ")\n", + "\n", + "if not rpc_response or not rpc_response.result:\n", + " raise ValueError(\"No result in response\")\n", + "\n", + "# Extract and print just the text content from the response\n", + "for task_message in rpc_response.result:\n", + " content = task_message.content\n", + " if isinstance(content, TextContent):\n", + " text = content.content\n", + " print(text)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79688331", + "metadata": {}, + "outputs": [], + "source": [ + "# Test streaming response\n", + "from 
agentex.types.task_message_update import StreamTaskMessageDelta, StreamTaskMessageFull\n",
"from agentex.types.text_delta import TextDelta\n",
"\n",
"\n",
"# The result object of message/send will be a TaskMessageUpdate which is a union of the following types:\n",
"# - StreamTaskMessageStart: \n",
"# - An indicator that a streaming message was started, doesn't contain any useful content\n",
"# - StreamTaskMessageDelta: \n",
"# - A delta of a streaming message, contains the text delta to aggregate\n",
"# - StreamTaskMessageDone: \n",
"# - An indicator that a streaming message was done, doesn't contain any useful content\n",
"# - StreamTaskMessageFull: \n",
"# - A non-streaming message, there is nothing to aggregate, since this contains the full message, not deltas\n",
"\n",
"# When processing StreamTaskMessageDelta, if you are expecting more than TextDeltas, such as DataDelta, ToolRequestDelta, or ToolResponseDelta, you can process them as well\n",
"# When processing StreamTaskMessageFull, if you are expecting more than TextContent, such as DataContent, ToolRequestContent, or ToolResponseContent, you can process them as well\n",
"\n",
"for agent_rpc_response_chunk in client.agents.send_message_stream(\n",
"    agent_name=AGENT_NAME,\n",
"    params={\n",
"        \"content\": {\"type\": \"text\", \"author\": \"user\", \"content\": \"Hello what can you do?\"},\n",
"        \"stream\": True\n",
"    }\n",
"):\n",
"    # We know that the result of the message/send when stream is set to True will be a TaskMessageUpdate\n",
"    task_message_update = agent_rpc_response_chunk.result\n",
"    # Print only the text deltas as they arrive or any full messages\n",
"    if isinstance(task_message_update, StreamTaskMessageDelta):\n",
"        delta = task_message_update.delta\n",
"        if isinstance(delta, TextDelta):\n",
"            print(delta.text_delta, end=\"\", flush=True)\n",
"        else:\n",
"            print(f\"Found non-text {type(delta)} object in streaming message.\")\n",
"    elif isinstance(task_message_update, StreamTaskMessageFull):\n",
"        content = task_message_update.content\n",
"        if isinstance(content, TextContent):\n",
"            print(content.content)\n",
"        else:\n",
"            print(f\"Found non-text {type(content)} object in full message.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5e7e042",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
diff --git a/src/agentex/lib/cli/templates/voice/environments.yaml.j2 b/src/agentex/lib/cli/templates/voice/environments.yaml.j2
new file mode 100644
index 000000000..f038b4775
--- /dev/null
+++ b/src/agentex/lib/cli/templates/voice/environments.yaml.j2
@@ -0,0 +1,37 @@
+# Agent Environment Configuration
+# ------------------------------
+# This file defines environment-specific settings for your voice agent.
+# This DIFFERS from manifest.yaml - use this for per-environment overrides.
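+#
+# For example, the dev environment below sets replicaCount: 2 under helm_overrides,
+# which is intended to take precedence over the default replicaCount defined in
+# manifest.yaml's deployment settings when deploying to dev.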
+ +schema_version: "v1" + +environments: + dev: + auth: + principal: + user_id: # TODO: Fill in + account_id: # TODO: Fill in + + # Helm overrides for dev environment + helm_overrides: + replicaCount: 2 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + + # Environment-specific variables + env: + - name: LOG_LEVEL + value: "DEBUG" + - name: ENVIRONMENT + value: "development" + + # Add more environments as needed + # staging: + # ... + # prod: + # ... diff --git a/src/agentex/lib/cli/templates/voice/manifest.yaml.j2 b/src/agentex/lib/cli/templates/voice/manifest.yaml.j2 new file mode 100644 index 000000000..bc73067ff --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/manifest.yaml.j2 @@ -0,0 +1,123 @@ +# Agent Manifest Configuration for Voice Agent +# ------------------------------------------ +# This file defines how your voice agent should be built and deployed. + +# Build Configuration +# ------------------ +build: + context: + root: ../ # Keep this as the default root + + # Paths to include in the Docker build context + include_paths: + - {{ project_path_from_build_root }} + + # Path to your agent's Dockerfile + dockerfile: {{ project_path_from_build_root }}/Dockerfile + + # Path to your agent's .dockerignore + dockerignore: {{ project_path_from_build_root }}/.dockerignore + + +# Local Development Configuration +# ----------------------------- +local_development: + agent: + port: 8000 # Port where your local ACP server runs + host_address: host.docker.internal + + paths: + # Path to ACP server file (relative to this manifest.yaml) + acp: project/acp.py + + +# Agent Configuration +# ----------------- +agent: + acp_type: sync # Voice agents use sync ACP + + # Unique name for your agent + name: {{ agent_name }} + + # Description of what your agent does + description: {{ description }} + + # Temporal workflow configuration + # Voice agents typically don't need Temporal + temporal: + enabled: false + + # Voice-specific credentials + # Uncomment and configure if using LiveKit + credentials: [] + # - env_var_name: LIVEKIT_URL + # secret_name: livekit-credentials + # secret_key: url + # - env_var_name: LIVEKIT_API_KEY + # secret_name: livekit-credentials + # secret_key: api-key + # - env_var_name: LIVEKIT_API_SECRET + # secret_name: livekit-credentials + # secret_key: api-secret + + # Environment variables for local development + # For deployed environments, use Kubernetes secrets (see credentials below) + env: + AGENTEX_BASE_URL: "" # Empty = disable platform registration for local testing + # LLM_MODEL: "gpt-4o-mini" # Override default model + + # LLM Credentials (choose one approach based on deployment) + # -------------------------------------------------------- + # OPTION 1: OpenAI (direct or via proxy) + # - env_var_name: OPENAI_API_KEY + # secret_name: openai-api-key-secret + # secret_key: api-key + # - env_var_name: OPENAI_BASE_URL # For proxy/internal endpoints + # secret_name: openai-config + # secret_key: base-url + # + # OPTION 2: Azure OpenAI + # - env_var_name: AZURE_OPENAI_API_KEY + # secret_name: azure-openai-secret + # secret_key: api-key + # - env_var_name: AZURE_OPENAI_ENDPOINT + # secret_name: azure-openai-secret + # secret_key: endpoint + # - env_var_name: AZURE_OPENAI_DEPLOYMENT + # secret_name: azure-openai-secret + # secret_key: deployment + # + # OPTION 3: Scale Groundplane (SGP) + # - env_var_name: SGP_API_KEY + # secret_name: sgp-credentials + # secret_key: api-key + # - env_var_name: SGP_ACCOUNT_ID + # secret_name: 
sgp-credentials + # secret_key: account-id + + +# Deployment Configuration +# ----------------------- +deployment: + # Container image configuration + image: + repository: "" # Update with your container registry + tag: "latest" + + imagePullSecrets: [] + # - name: my-registry-secret + + # Global deployment settings + global: + # Default replica count + replicaCount: 1 + + # Default resource requirements + # Adjust based on your agent's needs + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" diff --git a/src/agentex/lib/cli/templates/voice/project/acp.py.j2 b/src/agentex/lib/cli/templates/voice/project/acp.py.j2 new file mode 100644 index 000000000..28ce7bf17 --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/project/acp.py.j2 @@ -0,0 +1,360 @@ +"""{{ agent_name }} - Voice Agent powered by LiveKit + +This agent uses the VoiceAgentBase class which provides: +- Automatic state management and conversation history +- Streaming support with interruption handling +- Guardrail system integration +- LiveKit voice infrastructure ready +""" + +import os +from typing import AsyncGenerator, Optional + +# Disable default Agentex tracing for local development +# Remove this block when deploying to Agentex platform +import agentex.lib.core.tracing.tracing_processor_manager as _tpm +_tpm._default_initialized = True + +from agentex.voice import VoiceAgentBase, AgentState, AgentResponse +from agentex.voice.guardrails import Guardrail +from agentex.lib.sdk.fastacp.fastacp import FastACP +from agentex.lib.types.acp import SendMessageParams +from agentex.lib.utils.logging import make_logger +from agentex.types import Span +from agentex.types.task_message_update import TaskMessageUpdate +from pydantic import Field +from openai import AsyncOpenAI + +logger = make_logger(__name__) + +# ============================================================================ +# LLM Configuration +# ============================================================================ +# This template supports multiple LLM providers via environment variables. +# Configure ONE of the following options: +# +# OPTION 1: OpenAI (direct or via proxy) +# OPENAI_API_KEY=sk-... +# OPENAI_BASE_URL=https://api.openai.com/v1 (optional, for proxies) +# LLM_MODEL=gpt-4o-mini (optional) +# +# OPTION 2: Azure OpenAI +# AZURE_OPENAI_API_KEY=... +# AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com +# AZURE_OPENAI_DEPLOYMENT=your-deployment-name +# LLM_MODEL=gpt-4o (optional) +# +# OPTION 3: Scale Groundplane (SGP) for Gemini +# SGP_API_KEY=... +# SGP_ACCOUNT_ID=... 
+# LLM_MODEL=gemini-2.0-flash (optional) +# +# OPTION 4: Any OpenAI-compatible endpoint +# OPENAI_API_KEY=your-key +# OPENAI_BASE_URL=https://your-internal-proxy.company.com/v1 +# LLM_MODEL=your-model-name +# +# For testing without credentials, set: MOCK_MODE=true +# ============================================================================ + +# Environment variables +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") +OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "") # Empty = use default +AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "") +AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT", "") +AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", "") +SGP_API_KEY = os.environ.get("SGP_API_KEY", "") +SGP_ACCOUNT_ID = os.environ.get("SGP_ACCOUNT_ID", "") +# Vertex AI (Google Cloud) +GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT", "") +GOOGLE_CLOUD_LOCATION = os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1") +GOOGLE_GENAI_USE_VERTEXAI = os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in ("true", "1", "yes") +LLM_MODEL_ENV = os.environ.get("LLM_MODEL", "") +MOCK_MODE = os.environ.get("MOCK_MODE", "").lower() in ("true", "1", "yes") + +openai_client = None +LLM_MODEL = LLM_MODEL_ENV or "gpt-4o-mini" # Default model + +def configure_llm_client(): + """Configure LLM client based on available environment variables.""" + global openai_client, LLM_MODEL, MOCK_MODE + + # Priority: Vertex AI > Azure > SGP > OpenAI > Mock + + if GOOGLE_GENAI_USE_VERTEXAI and GOOGLE_CLOUD_PROJECT: + # Vertex AI using OpenAI-compatible endpoint + try: + import google.auth + from google.auth.transport.requests import Request + + # Get credentials from environment (GOOGLE_APPLICATION_CREDENTIALS) + credentials, project = google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + credentials.refresh(Request()) + + # Vertex AI OpenAI-compatible endpoint + vertex_base_url = f"https://{GOOGLE_CLOUD_LOCATION}-aiplatform.googleapis.com/v1beta1/projects/{GOOGLE_CLOUD_PROJECT}/locations/{GOOGLE_CLOUD_LOCATION}/endpoints/openapi" + + logger.info(f"Configuring Vertex AI client: project={GOOGLE_CLOUD_PROJECT}, location={GOOGLE_CLOUD_LOCATION}") + openai_client = AsyncOpenAI( + api_key=credentials.token, # Use OAuth token as API key + base_url=vertex_base_url, + ) + LLM_MODEL = LLM_MODEL_ENV or "google/gemini-2.0-flash" + logger.info(f"Vertex AI configured for model: {LLM_MODEL}") + return + except Exception as e: + logger.warning(f"Failed to configure Vertex AI: {e}. 
Falling back to other providers.") + + if AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT: + # Azure OpenAI + from openai import AsyncAzureOpenAI + logger.info(f"Configuring Azure OpenAI client: {AZURE_OPENAI_ENDPOINT}") + openai_client = AsyncAzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + azure_endpoint=AZURE_OPENAI_ENDPOINT, + api_version="2024-02-01", + ) + LLM_MODEL = LLM_MODEL_ENV or AZURE_OPENAI_DEPLOYMENT or "gpt-4o" + logger.info(f"Azure OpenAI configured with model: {LLM_MODEL}") + return + + if SGP_API_KEY and SGP_ACCOUNT_ID: + # Scale Groundplane with Gemini + logger.info(f"Configuring SGP client with account ID: {SGP_ACCOUNT_ID[:8]}...") + openai_client = AsyncOpenAI( + api_key=SGP_API_KEY, + base_url="https://api.egp.scale.com/beta", + default_headers={"x-selected-account-id": SGP_ACCOUNT_ID}, + ) + LLM_MODEL = LLM_MODEL_ENV or "gemini-2.0-flash" + logger.info(f"SGP client configured for model: {LLM_MODEL}") + return + + if OPENAI_API_KEY: + # OpenAI (direct or via proxy/custom endpoint) + client_kwargs = {"api_key": OPENAI_API_KEY} + if OPENAI_BASE_URL: + client_kwargs["base_url"] = OPENAI_BASE_URL + logger.info(f"Configuring OpenAI client with custom base URL: {OPENAI_BASE_URL}") + else: + logger.info("Configuring OpenAI client (direct)") + + openai_client = AsyncOpenAI(**client_kwargs) + LLM_MODEL = LLM_MODEL_ENV or "gpt-4o-mini" + logger.info(f"OpenAI client configured for model: {LLM_MODEL}") + return + + # No credentials - enable mock mode + MOCK_MODE = True + logger.warning("No LLM credentials set - running in MOCK MODE (returns test responses)") + +# Initialize LLM client +configure_llm_client() + +# Create ACP server +acp = FastACP.create(acp_type="sync") + + +# ============================================================================ +# Custom State and Response Models (optional) +# ============================================================================ + +class {{ agent_class_name }}State(AgentState): + """Custom state for {{ agent_name }}. + + Extend AgentState to add agent-specific fields. + + Example: + conversation_phase: str = "introduction" + user_preferences: dict = Field(default_factory=dict) + """ + pass + + +class {{ agent_class_name }}Response(AgentResponse): + """Custom response for {{ agent_name }}. + + Extend AgentResponse to add structured output fields. + + Example: + phase_transition: bool = False + new_phase: Optional[str] = None + """ + pass + + +# ============================================================================ +# Agent Implementation +# ============================================================================ + +class {{ agent_class_name }}(VoiceAgentBase): + """Voice agent for {{ agent_name }}.""" + + # Specify custom models + state_class = {{ agent_class_name }}State + response_class = {{ agent_class_name }}Response + + def get_system_prompt( + self, + conversation_state: {{ agent_class_name }}State, + guardrail_override: Optional[str] = None + ) -> str: + """Return the system prompt for this agent. + + Args: + conversation_state: Current conversation state + guardrail_override: If provided, use this prompt (for guardrail failures) + + Returns: + System prompt string + """ + if guardrail_override: + return guardrail_override + + # TODO: Customize this prompt for your agent + return """ + +You are a helpful voice assistant for {{ description }}. 
+ + + +- Speak naturally and conversationally +- Be empathetic and warm +- Keep responses concise for voice interaction +- Use appropriate pacing and tone +- Avoid technical jargon unless necessary + + + +# TODO: Describe what this agent can do +- Answer questions about [topic] +- Help with [task] +- Provide information about [subject] + + + +- Always be respectful and professional +- If you don't know something, admit it +- Stay within your defined capabilities +- Redirect inappropriate or off-topic requests politely + +""" + + def update_state_and_tracing_from_response( + self, + conversation_state: {{ agent_class_name }}State, + response_data: {{ agent_class_name }}Response, + span: Span, + ) -> {{ agent_class_name }}State: + """Update state after LLM response. + + Args: + conversation_state: Current conversation state + response_data: Structured response from LLM + span: Tracing span for logging + + Returns: + Updated conversation state + """ + # TODO: Update state based on response_data fields + # Example: + # if response_data.phase_transition: + # conversation_state.conversation_phase = response_data.new_phase + + # Set span output for tracing + span.output = response_data + + return conversation_state + + +# ============================================================================ +# Guardrails (optional) +# ============================================================================ + +# TODO: Add guardrails for your agent +# Example: +# class MyCustomGuardrail(Guardrail): +# def __init__(self): +# super().__init__( +# name="my_guardrail", +# outcome_prompt="I can't help with that." +# ) +# +# async def check(self, user_message, conversation_state): +# return "bad_word" not in user_message.lower() + +GUARDRAILS = [ + # Add your guardrails here + # MyCustomGuardrail(), +] + + +# ============================================================================ +# Tools (optional) +# ============================================================================ + +# TODO: Add tools for your agent using the @function_tool decorator +# Example: +# from agents import function_tool +# +# @function_tool +# async def get_weather(location: str) -> str: +# """Get the weather for a location.""" +# return f"The weather in {location} is sunny." + +TOOLS = [ + # Add your tools here +] + + +# ============================================================================ +# Agent Initialization +# ============================================================================ + +AGENT = {{ agent_class_name }}( + agent_name="{{ agent_name }}", + llm_model=LLM_MODEL, # Set by LLM configuration above + tools=TOOLS, + guardrails=GUARDRAILS, + openai_client=openai_client, +) + + +# ============================================================================ +# ACP Handler +# ============================================================================ + +@acp.on_message_send +async def handle_message_send( + params: SendMessageParams, +) -> AsyncGenerator[TaskMessageUpdate, None]: + """Handle incoming voice messages with streaming support. + + This is the main entry point for the agent. It delegates to + VoiceAgentBase.send_message() which handles all the complex logic. 
+ """ + logger.info(f"Received message for {{ agent_name }}") + + # Mock mode for testing without LLM credentials + if MOCK_MODE: + from agentex.types import TextContent + from agentex.types.task_message_update import StreamTaskMessageFull + + user_message = params.content.content if hasattr(params.content, 'content') else str(params.content) + logger.info(f"MOCK MODE: Returning test response for: {user_message[:50]}...") + + yield StreamTaskMessageFull( + type="full", + index=0, + content=TextContent( + type="text", + author="agent", + content=f"[MOCK MODE] Hello! I received your message: '{user_message[:100]}'. This is a test response. Configure OPENAI_API_KEY or SGP credentials for real LLM responses.", + ), + ) + return + + async for chunk in AGENT.send_message(params): + yield chunk diff --git a/src/agentex/lib/cli/templates/voice/pyproject.toml.j2 b/src/agentex/lib/cli/templates/voice/pyproject.toml.j2 new file mode 100644 index 000000000..57ced60c6 --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/pyproject.toml.j2 @@ -0,0 +1,36 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "{{ project_name }}" +version = "0.1.0" +description = "{{ description }}" +requires-python = ">=3.12" +dependencies = [ + "agentex-sdk", # Includes agentex.voice module + "openai", # LLM client (OpenAI, Azure, Vertex AI compatible) + "google-auth", # For Vertex AI authentication + "partial-json-parser", # For streaming JSON parsing +] + +[project.optional-dependencies] +dev = [ + "pytest", + "black", + "isort", + "flake8", + "ipython", + "jupyter", +] + +[tool.hatch.build.targets.wheel] +packages = ["project"] + +[tool.black] +line-length = 88 +target-version = ['py312'] + +[tool.isort] +profile = "black" +line_length = 88 diff --git a/src/agentex/lib/cli/templates/voice/requirements.txt.j2 b/src/agentex/lib/cli/templates/voice/requirements.txt.j2 new file mode 100644 index 000000000..45c24e250 --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/requirements.txt.j2 @@ -0,0 +1,9 @@ +# Core dependencies for voice agents +agentex-sdk # Includes agentex.voice module + +# LLM provider +scale-gp # For Gemini (vertex_ai models) + +# Optional: Add other dependencies here +# openai +# anthropic diff --git a/src/agentex/lib/cli/templates/voice/test_agent.py.j2 b/src/agentex/lib/cli/templates/voice/test_agent.py.j2 new file mode 100644 index 000000000..7de4684f4 --- /dev/null +++ b/src/agentex/lib/cli/templates/voice/test_agent.py.j2 @@ -0,0 +1,70 @@ +""" +Sample tests for AgentEx ACP agent. + +This test suite demonstrates how to test the main AgentEx API functions: +- Non-streaming message sending +- Streaming message sending +- Task creation via RPC + +To run these tests: +1. Make sure the agent is running (via docker-compose or `agentex agents run`) +2. Set the AGENTEX_API_BASE_URL environment variable if not using default +3. 
Run: pytest test_agent.py -v
+
+Configuration:
+- AGENTEX_API_BASE_URL: Base URL for the AgentEx server (default: http://localhost:5003)
+- AGENT_NAME: Name of the agent to test (default: {{ agent_name }})
+"""
+
+import os
+import pytest
+from agentex import Agentex
+
+
+# Configuration from environment variables
+AGENTEX_API_BASE_URL = os.environ.get("AGENTEX_API_BASE_URL", "http://localhost:5003")
+AGENT_NAME = os.environ.get("AGENT_NAME", "{{ agent_name }}")
+
+
+@pytest.fixture
+def client():
+    """Create an AgentEx client instance for testing."""
+    return Agentex(base_url=AGENTEX_API_BASE_URL)
+
+
+@pytest.fixture
+def agent_name():
+    """Return the agent name for testing."""
+    return AGENT_NAME
+
+
+@pytest.fixture
+def agent_id(client, agent_name):
+    """Retrieve the agent ID based on the agent name."""
+    agents = client.agents.list()
+    for agent in agents:
+        if agent.name == agent_name:
+            return agent.id
+    raise ValueError(f"Agent with name {agent_name} not found.")
+
+
+class TestNonStreamingMessages:
+    """Test non-streaming message sending."""
+
+    def test_send_message(self, client: Agentex, agent_name: str):
+        """Test sending a message and receiving a response."""
+        # TODO: Fill in the test based on what data your agent is expected to handle
+        ...
+
+
+class TestStreamingMessages:
+    """Test streaming message sending."""
+
+    def test_send_stream_message(self, client: Agentex, agent_name: str):
+        """Test streaming a message and aggregating deltas."""
+        # TODO: Fill in the test based on what data your agent is expected to handle
+        ...
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/src/agentex/voice/__init__.py b/src/agentex/voice/__init__.py
new file mode 100644
index 000000000..6ac2e2a5c
--- /dev/null
+++ b/src/agentex/voice/__init__.py
@@ -0,0 +1,31 @@
+"""Voice Agent SDK module for building LiveKit-powered voice agents.
+
+This module provides base classes and utilities for creating production-ready
+voice agents with state management, guardrails, and streaming support.
+
+Example usage:
+    from agentex.voice import VoiceAgentBase, AgentState, AgentResponse
+
+    class MyVoiceAgent(VoiceAgentBase):
+        def get_system_prompt(self, state, guardrail_override=None):
+            return "You are a helpful voice assistant."
+
+        def update_state_and_tracing_from_response(self, state, response, span):
+            span.output = response
+            return state
+"""
+
+from agentex.voice.agent import VoiceAgentBase
+from agentex.voice.models import AgentState, AgentResponse, ProcessingInfo
+from agentex.voice.guardrails import Guardrail, LLMGuardrail
+
+__all__ = [
+    "VoiceAgentBase",
+    "AgentState",
+    "AgentResponse",
+    "ProcessingInfo",
+    "Guardrail",
+    "LLMGuardrail",
+]
+
+__version__ = "0.1.0"
diff --git a/src/agentex/voice/agent.py b/src/agentex/voice/agent.py
new file mode 100644
index 000000000..8ac92dc20
--- /dev/null
+++ b/src/agentex/voice/agent.py
@@ -0,0 +1,788 @@
+"""Voice Agent base class for LiveKit-powered voice agents.
+ +This module provides the VoiceAgentBase class which handles: +- State management and persistence +- Message interruption handling +- Guardrail execution +- Streaming with concurrent guardrail checks +- Conversation history management +""" + +import asyncio +import time +import uuid +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Generic, Optional, Type, TypeVar + +from agentex.lib import adk +from agentex.lib.types.acp import SendMessageParams +from agentex.lib.utils.logging import make_logger +from agentex.types import DataContent, Span, TextContent, TextDelta +from agentex.types.task_message_update import ( + StreamTaskMessageDelta, + StreamTaskMessageDone, + StreamTaskMessageFull, + StreamTaskMessageStart, + TaskMessageUpdate, +) +from agents import FunctionTool, OpenAIChatCompletionsModel +from openai.types.responses import ResponseTextDeltaEvent +from partial_json_parser import MalformedJSON +from partial_json_parser import loads as partial_json_loads +from pydantic import ValidationError + +from agentex.voice.guardrails import Guardrail +from agentex.voice.models import AgentResponse, AgentState, ProcessingInfo + +logger = make_logger(__name__) + +# Timeout for processing info - if older than this, consider it stale/crashed +PROCESSING_TIMEOUT_SECONDS = 5 + + +# Define TypeVars bounded to the base types +TState = TypeVar("TState", bound=AgentState) +TResponse = TypeVar("TResponse", bound=AgentResponse) + + +class VoiceAgentBase(ABC, Generic[TState, TResponse]): + """Base class for voice agents with LiveKit integration. + + This class provides: + - Automatic state management and persistence + - Message interruption handling for voice + - Guardrail system integration + - Streaming with concurrent processing + - Conversation history tracking + + Subclasses must implement: + - get_system_prompt(): Return the LLM system prompt + - update_state_and_tracing_from_response(): Update state after LLM response + + Optional override: + - finish_agent_turn(): Stream additional content after main response + + Example: + class MyVoiceAgent(VoiceAgentBase): + state_class = MyAgentState + response_class = MyAgentResponse + + def get_system_prompt(self, state, guardrail_override=None): + return "You are a helpful assistant." + + def update_state_and_tracing_from_response(self, state, response, span): + span.output = response + return state + """ + + # Subclasses must define these class attributes + state_class: Type[TState] = AgentState # type: ignore + response_class: Type[TResponse] = AgentResponse # type: ignore + + def __init__( + self, + agent_name: str, + llm_model: str, + tools: Optional[list[FunctionTool]] = None, + guardrails: Optional[list[Guardrail]] = None, + openai_client = None, + ): + """Initialize the voice agent. + + Args: + agent_name: Unique name for this agent + llm_model: LLM model identifier (e.g., "vertex_ai/gemini-2.5-flash") + tools: List of FunctionTools for the agent to use + guardrails: List of Guardrails to enforce + openai_client: OpenAI-compatible client (defaults to adk default) + """ + self.agent_name = agent_name + self.llm_model = llm_model + self.tools = tools or [] + self.guardrails = guardrails or [] + self.openai_client = openai_client + + ### Abstract methods - must be implemented by subclasses + + @abstractmethod + def get_system_prompt( + self, conversation_state: TState, guardrail_override: Optional[str] = None + ) -> str: + """Generate the system prompt for the agent LLM. 
+ + Args: + conversation_state: Current conversation state + guardrail_override: If provided, use this as the prompt (for guardrail failures) + + Returns: + System prompt string + """ + pass + + @abstractmethod + def update_state_and_tracing_from_response( + self, conversation_state: TState, response_data: TResponse, span: Span + ) -> TState: + """Update and return the conversation state based on the response data from the agent LLM. + + Args: + conversation_state: Current conversation state + response_data: Structured response from LLM + span: Tracing span for logging + + Returns: + Updated conversation state + """ + pass + + async def finish_agent_turn( + self, conversation_state: TState + ) -> AsyncGenerator[TaskMessageUpdate, None]: + """Stream any additional chunks to the user after the main response. + + Default implementation yields nothing. Override this method if your agent + needs to stream additional content after the main response. + + Args: + conversation_state: Current conversation state + + Yields: + TaskMessageUpdate objects for additional content + """ + return + yield # This line is never reached but makes it an async generator + + ### State management / interruption handling methods + + async def get_or_create_conversation_state( + self, task_id: str, agent_id: str + ) -> tuple[TState, str | None]: + """Get existing conversation state or create a new one. + + Args: + task_id: Unique task identifier + agent_id: Unique agent identifier + + Returns: + Tuple of (state, state_id) + """ + try: + # Try to get existing state + task_state = await adk.state.get_by_task_and_agent(task_id=task_id, agent_id=agent_id) + + if task_state and task_state.state: + # Parse existing state + if isinstance(task_state.state, dict): + return self.state_class(**task_state.state), task_state.id + else: + return task_state.state, task_state.id + else: + # Create new state + new_state = self.state_class() + created_state = await adk.state.create( + task_id=task_id, agent_id=agent_id, state=new_state + ) + return new_state, created_state.id + + except Exception as e: + logger.warning(f"Could not retrieve state, creating new: {e}") + # Fallback to new state + new_state = self.state_class() + try: + created_state = await adk.state.create( + task_id=task_id, agent_id=agent_id, state=new_state + ) + return new_state, created_state.id + except Exception: + # If creation fails, just use in-memory state + return new_state, None + + async def check_if_interrupted(self, task_id: str, agent_id: str, my_message_id: str) -> bool: + """Check if this message's processing has been interrupted by another message. + + Args: + task_id: Unique task identifier + agent_id: Unique agent identifier + my_message_id: ID of the message being processed + + Returns: + True if interrupted, False otherwise + """ + state, _ = await self.get_or_create_conversation_state(task_id, agent_id) + if state.processing_info: + return ( + state.processing_info.interrupted + and state.processing_info.message_id == my_message_id + ) + return False + + async def clear_processing_info( + self, task_id: str, agent_id: str, state_id: str | None + ) -> None: + """Clear processing_info to signal processing has ended. 
+ + Args: + task_id: Unique task identifier + agent_id: Unique agent identifier + state_id: State ID for updates (can be None) + """ + if not state_id: + return + + try: + state, _ = await self.get_or_create_conversation_state(task_id, agent_id) + if state.processing_info: + state.processing_info = None + await adk.state.update( + state_id=state_id, + task_id=task_id, + agent_id=agent_id, + state=state, + ) + logger.info("Cleared processing_info") + except Exception as e: + logger.warning(f"Failed to clear processing_info: {e}") + + async def wait_for_processing_clear( + self, + task_id: str, + agent_id: str, + interrupted_message_id: str, + timeout: float = 5.0, + poll_interval: float = 0.2, + ) -> bool: + """Wait for the interrupted processor to acknowledge and clear processing_info. + + Args: + task_id: Unique task identifier + agent_id: Unique agent identifier + interrupted_message_id: ID of the message that was interrupted + timeout: Maximum time to wait (seconds) + poll_interval: How often to check (seconds) + + Returns: + True if cleared, False if timeout reached + """ + start_time = time.time() + while time.time() - start_time < timeout: + state, _ = await self.get_or_create_conversation_state(task_id, agent_id) + + # Check if processing_info is cleared or belongs to a different message + if not state.processing_info: + return True + if state.processing_info.message_id != interrupted_message_id: + return True + + await asyncio.sleep(poll_interval) + + # Timeout - old processor may have crashed, proceed anyway + logger.warning("Timeout waiting for interrupted processor to clear") + return False + + async def handle_message_interruption( + self, + task_id: str, + agent_id: str, + state_id: str | None, + new_content: str, + new_message_id: str, + state: TState, + ) -> tuple[str, bool]: + """Handle interruption of active processing when a new message arrives. + + This is critical for voice agents where users may speak over the agent + or correct themselves mid-sentence. 
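+
+        For example (illustrative values): if the interrupted message was
+        "I want to book" and the new message is "I want to book a flight",
+        the new message is treated as a prefix extension and processed on its
+        own; if the new message is "actually, never mind", the two are
+        concatenated into "I want to book actually, never mind".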
+ + Args: + task_id: Unique task identifier + agent_id: Unique agent identifier + state_id: State ID for updates + new_content: New message content + new_message_id: ID of the new message + state: Current conversation state + + Returns: + Tuple of (final_content, was_interrupted) where: + - final_content: The message content to process (may be combined with old content) + - was_interrupted: True if we interrupted another message's processing + """ + if not state.processing_info: + # No active processing - process normally + return new_content, False + + # Check if processing info is stale (processor may have crashed) + processing_age = time.time() - state.processing_info.started_at + if processing_age > PROCESSING_TIMEOUT_SECONDS: + logger.warning(f"Processing info is stale (age: {processing_age:.1f}s), ignoring") + return new_content, False + + # Check if already interrupted (another message already took over) + if state.processing_info.interrupted: + logger.info("Processing already interrupted by another message") + return new_content, False + + old_content = state.processing_info.message_content + old_message_id = state.processing_info.message_id + + # Determine final content based on prefix check + if new_content.startswith(old_content): + # New message is a prefix extension - use new content only + final_content = new_content + logger.info("New message is prefix extension of old, using new content only") + else: + # Concatenate old and new content + final_content = old_content + " " + new_content + logger.info("Concatenating old and new message content") + + # Signal interruption to the old processor + state.processing_info.interrupted = True + state.processing_info.interrupted_by = new_message_id + + if state_id: + await adk.state.update( + state_id=state_id, + task_id=task_id, + agent_id=agent_id, + state=state, + ) + + # Wait for old processor to acknowledge and clean up + await self.wait_for_processing_clear(task_id, agent_id, old_message_id) + + return final_content, True + + async def save_state( + self, + state: TState, + state_id: str | None, + task_id: str, + agent_id: str, + ) -> bool: + """Save the conversation state. + + Args: + state: Conversation state to save + state_id: State ID for updates + task_id: Unique task identifier + agent_id: Unique agent identifier + + Returns: + True if saved successfully, False otherwise + """ + if not state_id: + logger.warning("No state_id provided, cannot save state") + return False + + # Clear processing_info and increment version + state.processing_info = None + state.state_version += 1 + + await adk.state.update( + state_id=state_id, + task_id=task_id, + agent_id=agent_id, + state=state, + ) + + logger.info(f"Saved state (version {state.state_version})") + return True + + ### LLM request methods + + async def run_all_guardrails( + self, user_message: str, conversation_state: TState + ) -> tuple[bool, list[Guardrail]]: + """Run all guardrails concurrently. 
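+
+        Messages of five characters or fewer skip the checks and are treated
+        as passing; a guardrail that raises an exception is counted as a
+        failure.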
+ + Args: + user_message: The user's input message + conversation_state: Current conversation state + + Returns: + Tuple of (all_passed, failed_guardrails): + - all_passed: True if all guardrails passed + - failed_guardrails: List of Guardrail instances that failed + """ + if len(self.guardrails) == 0: + return True, [] + + if len(user_message) <= 5: + logger.info("Skipping guardrails for short message") + return True, [] + + logger.info(f"Running {len(self.guardrails)} guardrails concurrently") + + # Run all guardrails concurrently + results = await asyncio.gather( + *[guardrail.check(user_message, conversation_state) for guardrail in self.guardrails], + return_exceptions=True, + ) + + # Process results + failed_guardrails = [] + all_passed = True + + for guardrail, result in zip(self.guardrails, results): + if isinstance(result, Exception): + logger.error(f"Guardrail {guardrail.name} raised exception: {result}") + failed_guardrails.append(guardrail) + all_passed = False + elif not result: + logger.warning(f"Guardrail {guardrail.name} failed") + failed_guardrails.append(guardrail) + all_passed = False + + logger.info(f"Guardrail check complete. All passed: {all_passed}") + return all_passed, failed_guardrails + + async def stream_response( + self, + conversation_state: TState, + max_turns: int = 1000, + guardrail_override: Optional[str] = None, + ) -> AsyncGenerator[str, None]: + """Call an LLM with streaming using the OpenAI Agents SDK. + + Args: + conversation_state: Current conversation state + max_turns: Maximum number of conversation turns to include + guardrail_override: Override prompt (used when guardrail fails) + + Yields: + Text deltas from the LLM + """ + recent_context = conversation_state.conversation_history + if max_turns is not None and len(recent_context) > max_turns: + recent_context = recent_context[-max_turns:] + + system_prompt = self.get_system_prompt(conversation_state, guardrail_override) + output_type = self.response_class if guardrail_override is None else None + + try: + result = await adk.providers.openai.run_agent_streamed( + input_list=recent_context, + mcp_server_params=[], + agent_name=self.agent_name, + agent_instructions=system_prompt, + model=OpenAIChatCompletionsModel( + model=self.llm_model, + openai_client=self.openai_client, + ), + tools=self.tools, + output_type=output_type, + max_turns=max_turns, + ) + + async for event in result.stream_events(): + if event.type == "raw_response_event" and isinstance( + event.data, ResponseTextDeltaEvent + ): + yield event.data.delta + + except Exception as e: + logger.error(f"Streaming LLM call failed: {e}", exc_info=True) + raise e + + ### Message handling methods + + async def handle_data_content_message( + self, + params: SendMessageParams, + span: Span, + ) -> AsyncGenerator[TaskMessageUpdate, None]: + """Handle DataContent messages - initialize conversation state to a specific point. + + Args: + params: Message parameters + span: Tracing span + + Yields: + TaskMessageUpdate with confirmation or error + """ + try: + new_state = self.state_class.model_validate(params.content.data) + created_state = await adk.state.create( + task_id=params.task.id, agent_id=params.agent.id, state=new_state + ) + response = f"Successfully initialized conversation state. 
State ID: {created_state.id}" + span.output = {"response_text": response} + yield StreamTaskMessageFull( + type="full", + index=0, + content=TextContent( + type="text", + author="agent", + content=response, + ), + ) + except ValidationError as e: + raise ValueError(f"Failed to create conversation state from provided data content: {e}") + + async def handle_text_content_message( + self, + params: SendMessageParams, + conversation_state: TState, + state_id: str | None, + span: Span, + message_id: str, + message_content: str, + ) -> AsyncGenerator[TaskMessageUpdate, None]: + """Handle TextContent messages with guardrails and streaming. + + This is the main message processing pipeline for voice agents. + + Args: + params: Message parameters + conversation_state: Current conversation state + state_id: State ID for updates + span: Tracing span + message_id: Unique message ID + message_content: The text content to process + + Yields: + TaskMessageUpdate objects for streaming response + """ + # Add user message to conversation history + conversation_state.conversation_history.append({"role": "user", "content": message_content}) + + # Start both guardrails and streaming concurrently + guardrail_task = asyncio.create_task( + self.run_all_guardrails(message_content, conversation_state) + ) + # Create an async generator that we'll consume + stream_generator = self.stream_response(conversation_state) + + # Buffer to store streaming chunks while waiting for guardrails + buffered_chunks = [] + full_json_response = "" + assistant_response_text = "" + guardrails_completed = False + guardrails_passed = False + failed_guardrails = [] + + # Consume stream and buffer until guardrails complete + try: + yield StreamTaskMessageStart( + type="start", + index=0, + content=TextContent(author="agent", content=""), + ) + async for chunk in stream_generator: + full_json_response += chunk + + # Check for interruption periodically + if await self.check_if_interrupted(params.task.id, params.agent.id, message_id): + logger.info("Processing interrupted by newer message, stopping") + await self.clear_processing_info(params.task.id, params.agent.id, state_id) + if not guardrail_task.done(): + guardrail_task.cancel() + yield StreamTaskMessageDone(type="done", index=0) + return + + # Check if guardrails have completed + if not guardrails_completed and guardrail_task.done(): + guardrails_completed = True + guardrails_passed, failed_guardrails = guardrail_task.result() + + if not guardrails_passed: + # Guardrails failed - stop processing + logger.warning(f"Guardrails failed: {[g.name for g in failed_guardrails]}") + break + else: + # Guardrails passed - yield all buffered chunks + logger.info("Guardrails passed, streaming response to user") + for buffered_chunk in buffered_chunks: + yield buffered_chunk + # Clear buffer as we've yielded everything + buffered_chunks.clear() + + # Process this chunk for streaming + try: + new_assistant_response_text = partial_json_loads(full_json_response).get( + "response_text", assistant_response_text + ) + if len(new_assistant_response_text) > len(assistant_response_text): + text_delta = new_assistant_response_text[len(assistant_response_text) :] + delta_message = StreamTaskMessageDelta( + type="delta", + index=0, + delta=TextDelta(text_delta=text_delta, type="text"), + ) + + if guardrails_completed and guardrails_passed: + # Guardrails already passed, stream directly + yield delta_message + else: + # Guardrails still running, buffer the chunk + buffered_chunks.append(delta_message) + + 
assistant_response_text = new_assistant_response_text + except MalformedJSON: + # usually this happens at the start of the stream + continue + + # If guardrails haven't completed yet, wait for them + if not guardrails_completed: + guardrails_passed, failed_guardrails = await guardrail_task + guardrails_completed = True + + if not guardrails_passed: + logger.warning(f"Guardrails failed: {[g.name for g in failed_guardrails]}") + else: + # Guardrails passed - yield all buffered chunks + logger.info( + "Guardrails passed (after streaming completed), yielding buffered response" + ) + for buffered_chunk in buffered_chunks: + yield buffered_chunk + + # If guardrails failed, stream using the prompt override + if not guardrails_passed: + # Use the first failed guardrail's prompt + failed_guardrail = failed_guardrails[0] + assistant_response_text = "" + async for text_delta in self.stream_response( + conversation_state, guardrail_override=failed_guardrail.outcome_prompt + ): + assistant_response_text += text_delta + yield StreamTaskMessageDelta( + type="delta", + index=0, + delta=TextDelta(text_delta=text_delta, type="text"), + ) + + span.output = { + "response_text": assistant_response_text, + "guardrails_hit": [gr.name for gr in failed_guardrails], + } + + else: + # Process the complete response to update state + response_data = self.response_class.model_validate_json(full_json_response) + conversation_state = self.update_state_and_tracing_from_response( + conversation_state, response_data, span + ) + + # Add agent response to conversation history + conversation_state.conversation_history.append( + {"role": "assistant", "content": assistant_response_text} + ) + + # Output any additional messages we want to surface to the user + async for update in self.finish_agent_turn(conversation_state): + yield update + + # Save updated state + if state_id: + await self.save_state( + state=conversation_state, + state_id=state_id, + task_id=params.task.id, + agent_id=params.agent.id, + ) + + except asyncio.CancelledError: + # Handle cancellation gracefully + logger.info("Streaming cancelled") + raise + except Exception as stream_error: + # Cancel guardrail task if it's still running + if not guardrail_task.done(): + guardrail_task.cancel() + raise stream_error + finally: + # Always clear processing_info + await self.clear_processing_info(params.task.id, params.agent.id, state_id) + yield StreamTaskMessageDone(type="done", index=0) + + async def send_message( + self, + params: SendMessageParams, + ) -> AsyncGenerator[TaskMessageUpdate, None]: + """Main entry point to send a message request to a voice agent. + + This is the method called by the ACP handler. 
It orchestrates: + - State retrieval/creation + - Interruption handling + - Routing to appropriate content handlers + - Error handling + + Args: + params: Message parameters from ACP + + Yields: + TaskMessageUpdate objects for streaming response + """ + # Use task_id as trace_id for consistency + trace_id = params.task.id + async with adk.tracing.span( + trace_id=trace_id, + name="handle_message_send", + input=params, + ) as span: + try: + if isinstance(params.content, DataContent): + # If DataContent is sent, try to initialize state from the sent data + async for update in self.handle_data_content_message(params, span): + yield update + + elif isinstance(params.content, TextContent): + # if TextContent is sent, process it as a voice message + # Generate a unique message ID for this processing request + message_id = f"{params.task.id}:{uuid.uuid4()}" + new_content = params.content.content + + # Get or create conversation state + conversation_state, state_id = await self.get_or_create_conversation_state( + params.task.id, params.agent.id + ) + + # Handle interruption if there's active processing + final_content, was_interrupted = await self.handle_message_interruption( + task_id=params.task.id, + agent_id=params.agent.id, + state_id=state_id, + new_content=new_content, + new_message_id=message_id, + state=conversation_state, + ) + + if was_interrupted: + # Re-read state after interruption cleanup + conversation_state, state_id = await self.get_or_create_conversation_state( + params.task.id, params.agent.id + ) + + # Set up processing_info for this message + conversation_state.processing_info = ProcessingInfo( + message_id=message_id, + message_content=final_content, + started_at=time.time(), + ) + if state_id: + await adk.state.update( + state_id=state_id, + task_id=params.task.id, + agent_id=params.agent.id, + state=conversation_state, + ) + + # Delegate to TextContent handler with guardrails + async for update in self.handle_text_content_message( + params, conversation_state, state_id, span, message_id, final_content + ): + yield update + + except Exception as e: + logger.error(f"Error processing voice message: {e}", exc_info=True) + # Return error message to user + span.output = {"error": str(e)} + yield StreamTaskMessageFull( + type="full", + index=1, + content=TextContent( + type="text", + author="agent", + content="I apologize, but I encountered an error. Could you please try again?", + ), + ) diff --git a/src/agentex/voice/guardrails.py b/src/agentex/voice/guardrails.py new file mode 100644 index 000000000..3ddbdaa2d --- /dev/null +++ b/src/agentex/voice/guardrails.py @@ -0,0 +1,304 @@ +"""Guardrail system for voice agents. + +This module provides base classes for implementing guardrails that check +user messages for policy violations, inappropriate content, or off-topic discussions. 
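+
+Guardrails are run concurrently by VoiceAgentBase.run_all_guardrails while the
+main LLM response streams; when a guardrail fails, its outcome_prompt is used
+to drive the agent's reply instead of the normal structured response.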
+""" + +import os +from abc import ABC, abstractmethod +from typing import Optional + +from agentex.lib import adk +from agentex.lib.utils.logging import make_logger +from agents import OpenAIChatCompletionsModel +from pydantic import BaseModel, Field + +from agentex.voice.models import AgentState + +logger = make_logger(__name__) + + +# ============================================================================ +# Guardrail Response Models +# ============================================================================ + + +class GuardrailResponse(BaseModel): + """Structured response for guardrail checks.""" + + pass_: bool = Field(alias="pass", description="Whether the check passed") + reason: str = Field(description="Brief explanation of the decision") + + +# ============================================================================ +# Base Guardrail Classes +# ============================================================================ + + +class Guardrail(ABC): + """Abstract base class for guardrails. + + Guardrails are checks that run on user messages to enforce policies, + detect inappropriate content, or guide conversations back on track. + + When a guardrail fails: + 1. The agent stops processing the normal LLM response + 2. The guardrail's outcome_prompt is used instead + 3. The agent responds with the guardrail-specific message + + Example: + class MyGuardrail(Guardrail): + def __init__(self): + super().__init__( + name="my_guardrail", + outcome_prompt="Sorry, I can't help with that." + ) + + async def check(self, user_message, conversation_state): + # Return True if passed, False if failed + return "bad_word" not in user_message.lower() + """ + + def __init__( + self, + name: str, + outcome_prompt: Optional[str] = None, + outcome_prompt_file: Optional[str] = None, + ): + """Initialize the guardrail. + + Args: + name: Unique identifier for this guardrail + outcome_prompt: Message to display when guardrail fails + outcome_prompt_file: Path to file containing outcome prompt + + Raises: + ValueError: If neither outcome_prompt nor outcome_prompt_file is provided + FileNotFoundError: If outcome_prompt_file doesn't exist + """ + self.name = name + if outcome_prompt is None and outcome_prompt_file is None: + raise ValueError("Either outcome_prompt or outcome_prompt_file must be provided.") + + if outcome_prompt is not None: + self.outcome_prompt = outcome_prompt + else: + # Load from file + if not os.path.exists(outcome_prompt_file): + raise FileNotFoundError(f"Outcome prompt file not found: {outcome_prompt_file}") + + with open(outcome_prompt_file, "r") as f: + self.outcome_prompt = f.read() + + @abstractmethod + async def check(self, user_message: str, conversation_state: AgentState) -> bool: + """Check if the user message passes this guardrail. + + Args: + user_message: The user's input message + conversation_state: Current conversation state + + Returns: + True if the message passes this guardrail, False otherwise + """ + pass + + +class LLMGuardrail(Guardrail): + """Base class for LLM-based guardrails that use prompts and structured output. + + This class provides a pattern for implementing guardrails that: + 1. Use an LLM to classify whether a message violates a policy + 2. Return structured GuardrailResponse with pass/fail and reasoning + 3. Load prompts from Jinja2 templates + 4. 
Support context-specific customization + + Subclasses should define: + - name: Unique identifier for the guardrail + - outcome_prompt: Message to display when guardrail fails + - prompt_template: Path to Jinja2 template file (.j2) + - model: Optional LLM model to use + + Example: + class MedicalAdviceGuardrail(LLMGuardrail): + def __init__(self): + super().__init__( + name="medical_advice", + outcome_prompt="I can't provide medical advice. Please consult your doctor.", + prompt_template="prompts/medical_advice_check.j2", + model="vertex_ai/gemini-2.5-flash-lite", + ) + + The prompt template receives these variables: + - user_message: The current user message + - previous_assistant_message: The last agent message (if any) + - Any additional kwargs passed to __init__ + """ + + def __init__( + self, + name: str, + outcome_prompt: str, + prompt_template: str, + model: str = "vertex_ai/gemini-2.5-flash-lite", + openai_client = None, + **template_kwargs, + ): + """Initialize LLM-based guardrail. + + Args: + name: Unique identifier for the guardrail + outcome_prompt: Message to display when guardrail fails + prompt_template: Path to Jinja2 template file (.j2) for the check prompt + model: LLM model to use for classification + openai_client: OpenAI-compatible client (defaults to adk default) + **template_kwargs: Additional variables to pass to the prompt template + """ + super().__init__(name=name, outcome_prompt=outcome_prompt) + self.prompt_template = prompt_template + self.model = model + self.openai_client = openai_client + self.template_kwargs = template_kwargs + + async def check(self, user_message: str, conversation_state: AgentState) -> bool: + """Check if the message passes this guardrail using LLM classification. + + Args: + user_message: The user's input message + conversation_state: Current conversation state + + Returns: + True if the message passes, False otherwise + + Raises: + Exception: If LLM call fails + """ + # Extract previous assistant message for context (if available) + latest_assistant_message = None + assistant_messages = [ + msg for msg in conversation_state.conversation_history if msg["role"] == "assistant" + ] + if assistant_messages: + latest_assistant_message = assistant_messages[-1]["content"] + + # Load and render the prompt template + # Note: Subclasses should implement their own prompt loading logic + # or use agentex.lib.utils.jinja helpers + prompt = self._load_prompt( + user_message=user_message, + previous_assistant_message=latest_assistant_message, + **self.template_kwargs + ) + + agent_instructions = self._get_agent_instructions() + + try: + # Call LLM with structured output + result = await adk.providers.openai.run_agent( + input_list=[{"role": "user", "content": prompt}], + mcp_server_params=[], + agent_name=f"guardrail_{self.name}", + agent_instructions=agent_instructions, + model=OpenAIChatCompletionsModel(model=self.model, openai_client=self.openai_client), + tools=[], + output_type=GuardrailResponse, + ) + # Parse the response + response = result.final_output + + # Log the guardrail check result + logger.info( + f"Guardrail '{self.name}' check - Pass: {response.pass_} - Reason: {response.reason}" + ) + + return response.pass_ + + except Exception as e: + logger.error(f"Guardrail '{self.name}' LLM call failed: {e}", exc_info=True) + raise e + + def _load_prompt(self, user_message: str, previous_assistant_message: Optional[str], **kwargs) -> str: + """Load and render the prompt template. + + This is a placeholder method. 
Subclasses should override this + to implement their own prompt loading logic using Jinja2. + + Args: + user_message: The current user message + previous_assistant_message: The last agent message (if any) + **kwargs: Additional template variables + + Returns: + Rendered prompt string + """ + # Default implementation - subclasses should override + return f""" +Evaluate if this user message violates the {self.name} policy: + +User message: {user_message} + +Previous assistant message: {previous_assistant_message or "None"} + +Respond with a JSON object containing: +- "pass": true if the message is acceptable, false if it violates the policy +- "reason": brief explanation of your decision +""" + + def _get_agent_instructions(self) -> str: + """Get the agent instructions for the guardrail LLM. + + Subclasses can override this to provide custom instructions. + + Returns: + Agent instructions string + """ + return f"""You are a policy compliance classifier for the {self.name} guardrail. + +Evaluate the user's message and determine if it passes or fails the policy check. + +Respond ONLY with valid JSON matching this schema: +{{ + "pass": true/false, + "reason": "brief explanation" +}} + +Be objective and consistent in your evaluations.""" + + +# ============================================================================ +# Example Guardrail Implementations +# ============================================================================ + + +class SampleMedicalEmergencyGuardrail(Guardrail): + """Example guardrail that checks for medical emergency mentions. + + This is a simple keyword-based example. Production implementations + should use LLMGuardrail for more sophisticated detection. + """ + + def __init__(self): + super().__init__( + name="medical_emergency", + outcome_prompt="If you're experiencing a medical emergency, please call 911 or your local emergency services immediately.", + ) + + async def check(self, user_message: str, conversation_state: AgentState) -> bool: + """Check if the message contains emergency keywords. + + Args: + user_message: The user's input message + conversation_state: Current conversation state + + Returns: + True if no emergency detected, False if emergency keywords found + """ + emergency_keywords = ["emergency", "911", "ambulance", "can't breathe"] + message_lower = user_message.lower() + + for keyword in emergency_keywords: + if keyword in message_lower: + logger.warning(f"Medical emergency keyword detected: {keyword}") + return False + + return True diff --git a/src/agentex/voice/models.py b/src/agentex/voice/models.py new file mode 100644 index 000000000..4e88c7a19 --- /dev/null +++ b/src/agentex/voice/models.py @@ -0,0 +1,53 @@ +"""Data models for voice agent state and responses.""" + +from typing import Optional + +from pydantic import BaseModel, Field + + +class ProcessingInfo(BaseModel): + """Processing information for distributed coordination and interruption handling. 
+ + This model tracks the current message being processed to enable: + - Interruption detection when a new message arrives + - Prefix checking to avoid duplicate processing + - Timeout detection for crashed processors + """ + + message_id: str # Unique ID for this message processing + message_content: str # Content being processed (for prefix checking) + started_at: float # Unix timestamp + interrupted: bool = False # Signal to stop processing + interrupted_by: Optional[str] = None # ID of interrupting message + + +class AgentState(BaseModel): + """Base state model for voice agent conversations. + + This tracks the conversation history and processing information. + Subclass this to add agent-specific state fields. + + Example: + class MyAgentState(AgentState): + custom_field: str = "default" + conversation_phase: str = "introduction" + """ + + conversation_history: list[dict[str, str]] = Field(default_factory=list) + processing_info: Optional[ProcessingInfo] = None + state_version: int = 0 # Increment on each successful save + + +class AgentResponse(BaseModel): + """Base response model for voice agent LLM outputs. + + This defines the structured output format from the LLM. + Subclass this to add agent-specific response fields. + + Example: + class MyAgentResponse(AgentResponse): + phase_transition: bool = False + new_phase: Optional[str] = None + """ + + response_text: str = Field(description="The agent's response to the user") diff --git a/src/agentex/voice/py.typed b/src/agentex/voice/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/uv.lock b/uv.lock index 0bbcf36b0..3aa7c132f 100644 --- a/uv.lock +++ b/uv.lock @@ -8,7 +8,7 @@ resolution-markers = [ [[package]] name = "agentex-sdk" -version = "0.8.0" +version = "0.9.1" source = { editable = "." 
} dependencies = [ { name = "aiohttp" }, @@ -20,6 +20,7 @@ dependencies = [ { name = "ddtrace" }, { name = "distro" }, { name = "fastapi" }, + { name = "google-auth" }, { name = "httpx" }, { name = "ipykernel" }, { name = "jinja2" }, @@ -31,6 +32,7 @@ dependencies = [ { name = "mcp", extra = ["cli"] }, { name = "openai" }, { name = "openai-agents" }, + { name = "partial-json-parser" }, { name = "pydantic" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -80,6 +82,7 @@ requires-dist = [ { name = "ddtrace", specifier = ">=3.13.0" }, { name = "distro", specifier = ">=1.7.0,<2" }, { name = "fastapi", specifier = ">=0.115.0,<0.116" }, + { name = "google-auth", specifier = ">=2.0.0" }, { name = "httpx", specifier = ">=0.27.2,<0.28" }, { name = "httpx-aiohttp", marker = "extra == 'aiohttp'", specifier = ">=0.1.9" }, { name = "ipykernel", specifier = ">=6.29.5" }, @@ -92,6 +95,7 @@ requires-dist = [ { name = "mcp", extras = ["cli"], specifier = ">=1.4.1" }, { name = "openai", specifier = ">=2.2,<3" }, { name = "openai-agents", specifier = "==0.4.2" }, + { name = "partial-json-parser", specifier = ">=0.2.1" }, { name = "pydantic", specifier = ">=2.0.0,<3" }, { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, @@ -102,7 +106,7 @@ requires-dist = [ { name = "rich", specifier = ">=13.9.2,<14" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.3.4" }, { name = "scale-gp", specifier = ">=0.1.0a59" }, - { name = "scale-gp-beta", specifier = "==0.1.0a20" }, + { name = "scale-gp-beta", specifier = ">=0.1.0a20" }, { name = "sniffio" }, { name = "temporalio", specifier = ">=1.18.2,<2" }, { name = "typer", specifier = ">=0.16,<0.17" }, @@ -1432,6 +1436,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668, upload-time = "2025-08-23T15:15:25.663Z" }, ] +[[package]] +name = "partial-json-parser" +version = "0.2.1.1.post7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/6d/eed37d7ebc1e0bcd27b831c0cf1fe94881934316187c4b30d23f29ea0bd4/partial_json_parser-0.2.1.1.post7.tar.gz", hash = "sha256:86590e1ba6bcb6739a2dfc17d2323f028cb5884f4c6ce23db376999132c9a922", size = 10296, upload-time = "2025-11-17T07:27:41.202Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/32/658973117bf0fd82a24abbfb94fe73a5e86216e49342985e10acce54775a/partial_json_parser-0.2.1.1.post7-py3-none-any.whl", hash = "sha256:145119e5eabcf80cbb13844a6b50a85c68bf99d376f8ed771e2a3c3b03e653ae", size = 10877, upload-time = "2025-11-17T07:27:40.457Z" }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2675,4 +2688,4 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, -] \ No newline at end of file +]