Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ test:
# Run all tests
test-all:
@just test
@just lint
@just format
@just lint
@just typecheck

# Run linting
Expand Down
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,18 @@ dev = [
]
# Table ingestion features (upload Parquet files, ingest to graph)
tables = [
"pandas>=1.5.0",
"pyarrow>=10.0.0",
"pandas>=2.1.0",
"pyarrow>=17.0.0",
]
# Legacy alias for extensions
extensions = [
"pandas>=1.5.0",
"pyarrow>=10.0.0",
"pandas>=2.1.0",
"pyarrow>=17.0.0",
]
# Install all optional features
all = [
"pandas>=1.5.0",
"pyarrow>=10.0.0",
"pandas>=2.1.0",
"pyarrow>=17.0.0",
]

[build-system]
Expand Down
166 changes: 163 additions & 3 deletions robosystems_client/extensions/graph_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Graph Management Client

Provides high-level graph management operations with automatic operation monitoring.
Supports both SSE (Server-Sent Events) for real-time updates and polling fallback.
"""

from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable
import time
import json
import logging

import httpx

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -62,19 +66,24 @@ def create_graph_and_wait(
timeout: int = 60,
poll_interval: int = 2,
on_progress: Optional[Callable[[str], None]] = None,
use_sse: bool = True,
) -> str:
"""
Create a graph and wait for completion.

Uses SSE (Server-Sent Events) for real-time progress updates with
automatic fallback to polling if SSE connection fails.

Args:
metadata: Graph metadata
initial_entity: Optional initial entity data
create_entity: Whether to create the entity node and upload initial data.
Only applies when initial_entity is provided. Set to False to create
graph without populating entity data (useful for file-based ingestion).
timeout: Maximum time to wait in seconds
poll_interval: Time between status checks in seconds
poll_interval: Time between status checks in seconds (for polling fallback)
on_progress: Callback for progress updates
use_sse: Whether to try SSE first (default True). Falls back to polling on failure.

Returns:
graph_id when creation completes
Expand All @@ -84,7 +93,6 @@ def create_graph_and_wait(
"""
from ..client import AuthenticatedClient
from ..api.graphs.create_graph import sync_detailed as create_graph
from ..api.operations.get_operation_status import sync_detailed as get_status
from ..models.create_graph_request import CreateGraphRequest
from ..models.graph_metadata import GraphMetadata as APIGraphMetadata

Expand Down Expand Up @@ -151,13 +159,165 @@ def create_graph_and_wait(
on_progress(f"Graph created: {graph_id}")
return graph_id

# Otherwise, poll operation until complete
# Otherwise, wait for operation to complete
if not operation_id:
raise RuntimeError("No graph_id or operation_id in response")

if on_progress:
on_progress(f"Graph creation queued (operation: {operation_id})")

# Try SSE first, fall back to polling
if use_sse:
try:
return self._wait_with_sse(operation_id, timeout, on_progress)
except Exception as e:
logger.debug(f"SSE connection failed, falling back to polling: {e}")
if on_progress:
on_progress("SSE unavailable, using polling...")

# Fallback to polling
return self._wait_with_polling(
operation_id, timeout, poll_interval, on_progress, client
)

def _wait_with_sse(
    self,
    operation_id: str,
    timeout: int,
    on_progress: Optional[Callable[[str], None]],
) -> str:
    """
    Wait for operation completion using SSE stream.

    Opens a streaming GET request against the operation's ``/stream``
    endpoint, parses the ``event:`` / ``data:`` fields line by line, and
    hands each complete event (terminated by a blank line) to
    ``_process_sse_event``.

    Args:
        operation_id: Operation ID to monitor
        timeout: Maximum time to wait in seconds
        on_progress: Callback for progress updates

    Returns:
        graph_id when operation completes

    Raises:
        RuntimeError: If the stream cannot be opened (non-200 status) or
            the operation fails
        TimeoutError: If operation times out or the stream ends before a
            completion event arrives
    """
    stream_url = f"{self.base_url}/v1/operations/{operation_id}/stream"
    headers = {"X-API-Key": self.token, "Accept": "text/event-stream"}

    # The transport timeout is slightly longer than the logical timeout so the
    # elapsed-time check below normally fires before httpx gives up on a read.
    with httpx.Client(timeout=httpx.Timeout(timeout + 5.0)) as http_client:
        with http_client.stream("GET", stream_url, headers=headers) as response:
            if response.status_code != 200:
                raise RuntimeError(f"SSE connection failed: {response.status_code}")

            # Monotonic clock: immune to wall-clock adjustments while waiting.
            deadline = time.monotonic() + timeout
            event_type = None
            data_lines = []

            for raw_line in response.iter_lines():
                # NOTE: the deadline is only evaluated when a line arrives; a
                # completely silent stream is bounded by the httpx timeout.
                if time.monotonic() > deadline:
                    raise TimeoutError(f"Graph creation timed out after {timeout}s")

                line = raw_line.strip()

                if not line:
                    # Blank line = end of event. Per the SSE format, multiple
                    # "data:" lines of one event are joined with newlines.
                    event_data = "\n".join(data_lines)
                    if event_type and event_data:
                        result = self._process_sse_event(
                            event_type, event_data, on_progress
                        )
                        if result is not None:
                            return result
                    event_type = None
                    data_lines = []
                    continue

                if line.startswith("event:"):
                    event_type = line[6:].strip()
                elif line.startswith("data:"):
                    # Accumulate instead of overwriting so multi-line payloads
                    # are not truncated to their last line.
                    data_lines.append(line[5:].strip())
                # Ignore other lines (comments, id, retry, etc.)

    # An event still pending here was never terminated by a blank line and is
    # intentionally discarded, as the SSE format prescribes.
    raise TimeoutError(f"SSE stream ended without completion after {timeout}s")

def _process_sse_event(
self,
event_type: str,
event_data: str,
on_progress: Optional[Callable[[str], None]],
) -> Optional[str]:
"""
Process a single SSE event.

Returns:
graph_id if operation completed, None to continue waiting

Raises:
RuntimeError: If operation failed
"""
try:
data = json.loads(event_data)
except json.JSONDecodeError:
logger.debug(f"Failed to parse SSE event data: {event_data}")
return None

if event_type == "operation_progress":
if on_progress:
message = data.get("message", "Processing...")
percent = data.get("progress_percent")
if percent is not None:
on_progress(f"{message} ({percent:.0f}%)")
else:
on_progress(message)
return None

elif event_type == "operation_completed":
result = data.get("result", {})
graph_id = result.get("graph_id") if isinstance(result, dict) else None

if graph_id:
if on_progress:
on_progress(f"Graph created: {graph_id}")
return graph_id
else:
raise RuntimeError("Operation completed but no graph_id in result")

elif event_type == "operation_error":
error = data.get("error", "Unknown error")
raise RuntimeError(f"Graph creation failed: {error}")

elif event_type == "operation_cancelled":
reason = data.get("reason", "Operation was cancelled")
raise RuntimeError(f"Graph creation cancelled: {reason}")

# Ignore other event types (keepalive, etc.)
return None

def _wait_with_polling(
self,
operation_id: str,
timeout: int,
poll_interval: int,
on_progress: Optional[Callable[[str], None]],
client: Any,
) -> str:
"""
Wait for operation completion using polling.

Args:
operation_id: Operation ID to monitor
timeout: Maximum time to wait in seconds
poll_interval: Time between status checks
on_progress: Callback for progress updates
client: Authenticated HTTP client

Returns:
graph_id when operation completes

Raises:
RuntimeError: If operation fails
TimeoutError: If operation times out
"""
from ..api.operations.get_operation_status import sync_detailed as get_status

max_attempts = timeout // poll_interval
for attempt in range(max_attempts):
time.sleep(poll_interval)
Expand Down
Loading