diff --git a/src/middlewares/openaiSDKmiddleware.py b/src/middlewares/openaiSDKmiddleware.py
new file mode 100644
index 00000000..6a3ec1d7
--- /dev/null
+++ b/src/middlewares/openaiSDKmiddleware.py
@@ -0,0 +1,145 @@
+import json
+from typing import Any, Dict, List, Optional
+from fastapi import HTTPException, Request
+from .middleware import jwt_middleware
+from .ratelimitMiddleware import rate_limit
+
+
+def _extract_pauthkey_from_authorization(request: Request) -> str:
+    authorization = request.headers.get("Authorization")
+    if not authorization or not authorization.lower().startswith("bearer "):
+        raise HTTPException(
+            status_code=401,
+            detail="Authorization header with Bearer pauthkey is required.",
+        )
+
+    token = authorization.split(" ", 1)[1].strip()
+    if not token:
+        raise HTTPException(status_code=401, detail="Bearer token cannot be empty.")
+    return token
+
+
+def _normalize_message_content(content: Any) -> Optional[str]:
+    if isinstance(content, str):
+        content = content.strip()
+        return content or None
+
+    if isinstance(content, list):
+        text_parts: List[str] = []
+        for item in content:
+            if isinstance(item, dict):
+                item_type = (item.get("type") or "").lower()
+                if item_type in {"text", "input_text"}:
+                    text_value = (item.get("text") or "").strip()
+                    if text_value:
+                        text_parts.append(text_value)
+        merged = "\n".join(text_parts).strip()
+        return merged or None
+
+    return None
+
+
+def _extract_text_from_input(input_value: Any) -> Optional[str]:
+    if isinstance(input_value, str):
+        text = input_value.strip()
+        return text or None
+
+    if isinstance(input_value, dict):
+        return _normalize_message_content(input_value.get("content"))
+
+    if isinstance(input_value, list):
+        segments: List[str] = []
+        for chunk in input_value:
+            if isinstance(chunk, dict):
+                content = chunk.get("content")
+                extracted = _normalize_message_content(content)
+                if extracted:
+                    segments.append(extracted)
+                elif isinstance(chunk.get("text"), str):
+                    text_value = chunk["text"].strip()
+                    if text_value:
+                        segments.append(text_value)
+        merged = "\n".join(segments).strip()
+        return merged or None
+
+    return None
+
+
+def _extract_agent_identifier(payload: Dict[str, Any]) -> str:
+    agent_id = payload.get("agent_id") or payload.get("bridge_id")
+
+    if isinstance(agent_id, str):
+        agent_id = agent_id.strip()
+    if not agent_id:
+        raise HTTPException(
+            status_code=400,
+            detail="`agent_id` must be included in the request body.",
+        )
+    return str(agent_id)
+
+
+def _build_internal_body(payload: Dict[str, Any]) -> Dict[str, Any]:
+    agent_id = _extract_agent_identifier(payload)
+    llm_model = payload.get("model")
+
+    user_message = _extract_text_from_input(payload.get("input"))
+
+    if not user_message:
+        raise HTTPException(status_code=400, detail="No user message found in payload.")
+
+    configuration = payload.get("configuration") or {}
+
+    if isinstance(llm_model, str) and llm_model.strip():
+        configuration.setdefault("model", llm_model.strip())
+
+    internal_body: Dict[str, Any] = {
+        "agent_id": agent_id,
+        "bridge_id": agent_id,
+        "user": user_message,
+        "messages": payload.get("messages", []),
+        "thread_id": payload.get("conversation_id")
+        or payload.get("thread_id") or None,
+        "sub_thread_id": payload.get("sub_thread_id") or None,
+        "variables": payload.get("variables") or {},
+        "configuration": configuration,
+        "attachments": payload.get("attachments", []),
+    }
+
+    return internal_body
+
+
+def _override_request_body(request: Request, body: Dict[str, Any]) -> None:
+    body_bytes = json.dumps(body).encode("utf-8")
+    request._body = body_bytes  # type: ignore[attr-defined]
+    request._json = body  # type: ignore[attr-defined]
+    request._stream_consumed = True  # type: ignore[attr-defined]
+    if "_form" in request.__dict__:
+        request.__dict__.pop("_form")
+
+
+def _set_pauthkey_header(request: Request, token: str) -> None:
+    raw_headers = list(request.scope.get("headers", []))
+    filtered_headers = [
+        (name, value)
+        for name, value in raw_headers
+        if name.lower() != b"authorization"
+    ]
+    filtered_headers.append((b"pauthkey", token.encode("utf-8")))
+    request.scope["headers"] = filtered_headers
+    if "_headers" in request.__dict__:
+        del request.__dict__["_headers"]
+
+
+async def openai_sdk_middleware(request: Request):
+    payload = await request.json()
+    internal_body = _build_internal_body(payload)
+    token = _extract_pauthkey_from_authorization(request)
+
+    _override_request_body(request, internal_body)
+    _set_pauthkey_header(request, token)
+    request.state.openai_payload = payload
+
+    await jwt_middleware(request)
+    await rate_limit(request, key_path="body.bridge_id", points=100)
+    await rate_limit(request, key_path="body.thread_id", points=20)
\ No newline at end of file
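A minimal sketch of the payload translation this middleware performs, assuming the module is importable as shown; the payload values are invented for illustration:

```python
# Hypothetical illustration of _build_internal_body; values are made up.
from src.middlewares.openaiSDKmiddleware import _build_internal_body

openai_payload = {
    "agent_id": "agent_123",        # required; `bridge_id` is also accepted
    "model": "gpt-4o",              # merged into configuration.model
    "input": [
        {"role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}
    ],
    "conversation_id": "thread_42", # mapped to thread_id
}

internal_body = _build_internal_body(openai_payload)
# internal_body == {
#     "agent_id": "agent_123", "bridge_id": "agent_123",
#     "user": "Hello!", "messages": [], "thread_id": "thread_42",
#     "sub_thread_id": None, "variables": {},
#     "configuration": {"model": "gpt-4o"}, "attachments": [],
# }
```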
diff --git a/src/routes/v2/modelRouter.py b/src/routes/v2/modelRouter.py
index cfba4525..d5e715b3 100644
--- a/src/routes/v2/modelRouter.py
+++ b/src/routes/v2/modelRouter.py
@@ -1,17 +1,21 @@
 from fastapi import APIRouter, Depends, Request, HTTPException
 from fastapi.responses import JSONResponse
 import asyncio
+import json
 from src.services.commonServices.common import chat_multiple_agents, embedding, batch, run_testcases, image, orchestrator_chat
 from src.services.commonServices.baseService.utils import make_request_data
 from ...middlewares.middleware import jwt_middleware
 from ...middlewares.getDataUsingBridgeId import add_configuration_data_to_body
+from ...middlewares.openaiSDKmiddleware import openai_sdk_middleware
+
 from concurrent.futures import ThreadPoolExecutor
 from config import Config
 from src.services.commonServices.queueService.queueService import queue_obj
 from src.middlewares.ratelimitMiddleware import rate_limit
 from models.mongo_connection import db
 from globals import *
+from src.services.utils.openai_sdk_utils import run_openai_chat_and_format
 
 router = APIRouter()
 
@@ -58,6 +62,11 @@ async def chat_completion(request: Request, db_config: dict = Depends(add_config
     return result
 
 
+@router.post('/openai/responses', dependencies=[Depends(openai_sdk_middleware)])
+async def openai_sdk_responses(request: Request, db_config: dict = Depends(add_configuration_data_to_body)):
+    return await run_openai_chat_and_format(request, db_config, chat_completion)
+
+
 @router.post('/playground/chat/completion/{bridge_id}', dependencies=[Depends(auth_and_rate_limit)])
 async def playground_chat_completion_bridge(request: Request, db_config: dict = Depends(add_configuration_data_to_body)):
     request.state.is_playground = True
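Since the new route mimics the OpenAI Responses API, it should be callable with the stock OpenAI Python SDK. A hedged sketch; the base URL and mount prefix are assumptions, and `agent_id` is the field the middleware requires:

```python
# Hypothetical client call; adjust base_url to wherever this router is mounted.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/api/v2/openai",  # assumed mount point
    api_key="<pauthkey>",  # sent as Bearer; the middleware remaps it to the pauthkey header
)

response = client.responses.create(
    model="gpt-4o",
    input="Hello!",
    extra_body={"agent_id": "agent_123"},  # required by openai_sdk_middleware
)
print(response.output_text)
```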
diff --git a/src/services/utils/common_utils.py b/src/services/utils/common_utils.py
index 796c31ea..fa02b6f0 100644
--- a/src/services/utils/common_utils.py
+++ b/src/services/utils/common_utils.py
@@ -52,11 +52,19 @@ def setup_agent_pre_tools(parsed_data, bridge_configurations):
         # Get required params from pre_tools_data
         required_params = pre_tools_data.get('required_params', [])
 
+        # Get variables_path mapping for the current agent; fall back to {} so
+        # a missing mapping doesn't raise TypeError on `param in variables_path`
+        variables_path = current_config.get('variables_path', {}).get(pre_tools_data.get('function_name')) or {}
+
         # Build args from agent's own variables
         args = {}
         for param in required_params:
-            if param in agent_variables:
-                args[param] = agent_variables[param]
+            # Check if there's a mapping in variables_path for this param
+            if param in variables_path:
+                # Get the mapped variable name
+                mapped_variable = variables_path[param]
+                # Use the mapped variable to get value from agent_variables
+                if mapped_variable in agent_variables:
+                    args[param] = agent_variables[mapped_variable]
 
         # Update the pre_tools args with agent-specific variables
         parsed_data['pre_tools']['args'] = args
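A self-contained sketch of the `variables_path` indirection above, with data shapes inferred from the code (the names `fetch_weather` and `user_city` are illustrative only):

```python
# Illustrative data only; shapes inferred from setup_agent_pre_tools.
current_config = {
    "variables_path": {
        "fetch_weather": {"city": "user_city"}  # pre-tool param -> agent variable
    }
}
agent_variables = {"user_city": "Berlin"}
pre_tools_data = {"function_name": "fetch_weather", "required_params": ["city"]}

variables_path = current_config.get("variables_path", {}).get(pre_tools_data.get("function_name")) or {}
args = {}
for param in pre_tools_data["required_params"]:
    if param in variables_path:
        mapped_variable = variables_path[param]
        if mapped_variable in agent_variables:
            args[param] = agent_variables[mapped_variable]
# args == {"city": "Berlin"}
```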
usage_data.get("output_tokens"), + "output_tokens_details": { + "reasoning_tokens": usage_data.get("reasoning_tokens"), + }, + "total_tokens": usage_data.get("total_tokens"), + }, + "user": original_payload.get("user") if isinstance(original_payload, dict) else None, + "metadata": original_payload.get("metadata") if isinstance(original_payload, dict) else {}, + "output_text": message_content, + "finish_reason": finish_reason or "stop", + } + + +def _format_error_detail( + message: str, + error_type: str = "invalid_request_error", + code: Optional[str] = None, + param: Optional[str] = None, +) -> Dict[str, Any]: + return { + "message": message, + "type": error_type, + "param": param, + "code": code + } + + +def _extract_error_message(error_payload: Dict[str, Any]) -> str: + error_value = error_payload.get("error") + + if isinstance(error_value, str): + return error_value + + if isinstance(error_value, dict): + return error_value.get("message") or error_value.get("detail") or json.dumps(error_value) + + if isinstance(error_payload.get("detail"), str): + return error_payload["detail"] + + return json.dumps(error_payload) + + +def _handle_failed_response( + response_payload: Dict[str, Any], + status_code: int = 400, +) -> None: + message = _extract_error_message(response_payload) + error_type = response_payload.get("error_type") or "invalid_request_error" + code = response_payload.get("error_code") + param = response_payload.get("error_param") + + raise HTTPException( + status_code=status_code, + detail=_format_error_detail(message, error_type=error_type, code=code, param=param), + ) + + +async def run_openai_chat_and_format( + request: Request, + db_config: Dict[str, Any], + chat_handler: Callable[[Request, Dict[str, Any]], Awaitable[Any]], +) -> Dict[str, Any]: + openai_payload = getattr(request.state, "openai_payload", {}) + internal_response = await chat_handler(request, db_config) + + if isinstance(internal_response, JSONResponse): + content = internal_response.body + try: + content_dict = json.loads(content) + except Exception: + content_dict = {} + if not content_dict.get("success", True): + status_code = content_dict.get("status_code") or 400 + _handle_failed_response(content_dict, status_code=status_code) + chat_response = content_dict + else: + chat_response = internal_response + + if isinstance(chat_response, dict) and not chat_response.get("success", True): + status_code = chat_response.get("status_code") or 400 + _handle_failed_response(chat_response, status_code=status_code) + + return format_openai_response(chat_response, openai_payload) diff --git a/src/services/utils/update_and_check_cost.py b/src/services/utils/update_and_check_cost.py index 76955b6c..b3220a36 100644 --- a/src/services/utils/update_and_check_cost.py +++ b/src/services/utils/update_and_check_cost.py @@ -127,9 +127,11 @@ async def check_bridge_api_folder_limits(result, bridge_data,version_id): if not isinstance(bridge_data, dict): return None - folder_identifier = result.get('folder_id') + result_data = result if isinstance(result, dict) else {} + + folder_identifier = result_data.get('folder_id') if folder_identifier: - folder_error = await _check_limit(limit_types['folder'],data=result,version_id=version_id) + folder_error = await _check_limit(limit_types['folder'],data=result_data,version_id=version_id) if folder_error: return folder_error @@ -137,12 +139,12 @@ async def check_bridge_api_folder_limits(result, bridge_data,version_id): if bridge_error: return bridge_error - service_identifier = 
diff --git a/src/services/utils/update_and_check_cost.py b/src/services/utils/update_and_check_cost.py
index 76955b6c..b3220a36 100644
--- a/src/services/utils/update_and_check_cost.py
+++ b/src/services/utils/update_and_check_cost.py
@@ -127,9 +127,11 @@ async def check_bridge_api_folder_limits(result, bridge_data,version_id):
     if not isinstance(bridge_data, dict):
         return None
 
-    folder_identifier = result.get('folder_id')
+    result_data = result if isinstance(result, dict) else {}
+
+    folder_identifier = result_data.get('folder_id')
     if folder_identifier:
-        folder_error = await _check_limit(limit_types['folder'],data=result,version_id=version_id)
+        folder_error = await _check_limit(limit_types['folder'],data=result_data,version_id=version_id)
         if folder_error:
             return folder_error
 
@@ -137,12 +139,12 @@
     if bridge_error:
         return bridge_error
 
-    service_identifier = result.get('service')
+    service_identifier = result_data.get('service')
     if service_identifier and (
-        (result.get('apikeys') and service_identifier in result.get('apikeys', {})) or
-        (result.get('folder_apikeys') and service_identifier in result.get('folder_apikeys', {}))
+        (result_data.get('apikeys') and service_identifier in result_data.get('apikeys', {})) or
+        (result_data.get('folder_apikeys') and service_identifier in result_data.get('folder_apikeys', {}))
     ):
-        api_error = await _check_limit(limit_types['apikey'], data=result,version_id=version_id)
+        api_error = await _check_limit(limit_types['apikey'], data=result_data,version_id=version_id)
         if api_error:
             return api_error
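A short sketch of why the `result_data` guard matters, assuming `result` can arrive as `None` or a non-dict:

```python
# Hypothetical illustration: with result=None, the old code raised
# AttributeError on result.get(...); the guard degrades gracefully to {}.
result = None
result_data = result if isinstance(result, dict) else {}
assert result_data.get('folder_id') is None  # no AttributeError
```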