From 7c1fa96d16a50627eb701d93ee5d7a6a5fb55373 Mon Sep 17 00:00:00 2001 From: ViaSocket-Git Date: Fri, 4 Apr 2025 17:52:26 +0530 Subject: [PATCH] added validation to chat/completion api --- src/middlewares/requestDataValidation.py | 45 +++++++++++++++ src/routes/v2/modelRouter.py | 3 +- validations/validation.py | 72 +++++++++++++++++++++++- 3 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 src/middlewares/requestDataValidation.py diff --git a/src/middlewares/requestDataValidation.py b/src/middlewares/requestDataValidation.py new file mode 100644 index 00000000..cde4cb55 --- /dev/null +++ b/src/middlewares/requestDataValidation.py @@ -0,0 +1,45 @@ +from fastapi import Depends, HTTPException +from fastapi.requests import Request +from validations.validation import ChatCompletionRequest +from pydantic import ValidationError +from typing import Dict, Any, List,Union + +def get_human_readable_error(exc: Union[ValidationError, List[Dict[str, Any]]]) -> Dict[str, Any]: + """ + Convert validation errors to human-readable format. + Handles both Pydantic ValidationError and raw error lists. + """ + # Get errors list from either source + errors_list = exc.errors() if isinstance(exc, ValidationError) else exc + + errors: List[Dict[str, Any]] = [] + + for error in errors_list: + loc = error.get('loc', ()) + field = '.'.join(str(l) for l in loc if l != '__root__') or 'root' + msg = error.get('msg', 'Invalid value') + errors.append({ + 'message': msg, + 'type': error.get('type', 'validation_error') + }) + + return { + "error": { + "message": "Validation failed", + "details": errors, + "suggestion": "Please check your input values" + } + } + +async def validate_request_data(request: Request): + try: + # Validate request body against Pydantic model + body = await request.json() + validated_data = ChatCompletionRequest(**body) # Validate the body with the Pydantic model + return validated_data + except ValidationError as ve: + # If validation error occurs, format the errors in a human-readable way + error_response = get_human_readable_error(ve.errors()) + raise HTTPException(status_code=400, detail=error_response) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid request data: {e}") diff --git a/src/routes/v2/modelRouter.py b/src/routes/v2/modelRouter.py index 89c76ef8..4be7c421 100644 --- a/src/routes/v2/modelRouter.py +++ b/src/routes/v2/modelRouter.py @@ -8,6 +8,7 @@ from config import Config from src.services.commonServices.queueService.queueService import queue_obj from src.middlewares.ratelimitMiddleware import rate_limit +from ...middlewares.requestDataValidation import validate_request_data router = APIRouter() @@ -19,7 +20,7 @@ async def auth_and_rate_limit(request: Request): await rate_limit(request,key_path='body.bridge_id' , points=100) await rate_limit(request,key_path='body.thread_id', points=20) -@router.post('/chat/completion', dependencies=[Depends(auth_and_rate_limit)]) +@router.post('/chat/completion', dependencies=[Depends(auth_and_rate_limit),Depends(validate_request_data)]) async def chat_completion(request: Request, db_config: dict = Depends(add_configuration_data_to_body)): request.state.is_playground = False request.state.version = 2 diff --git a/validations/validation.py b/validations/validation.py index 8fa61c20..3657ee53 100644 --- a/validations/validation.py +++ b/validations/validation.py @@ -1,9 +1,11 @@ -from pydantic import BaseModel,Field,constr,conint,confloat,validator, ConfigDict -from typing import Optional, Dict,List,Any +from pydantic import BaseModel,Field,constr,conint,confloat,validator, ConfigDict,model_validator +from typing import Optional, Dict,List,Any,Literal from typing_extensions import Annotated +from src.configs.model_configuration import model_config_document + # class ModelConfig(BaseModel): # model: str # creativity_level: Optional[float] = Field(None, ge=0, le=2) @@ -49,4 +51,68 @@ class Bridge_update(BaseModel): name: Optional[str] = None apikey_object_id: Optional[str] = None functionData: Optional[object] - \ No newline at end of file + +class ChatCompletionRequest(BaseModel): + user: str = Field(..., description="User identifier") + bridge_id: str = Field(..., description="Bridge identifier") + variables: Dict[str, Any] = Field(default_factory=dict, description="Template variables") + model: Optional[str] = Field(None, description="Model name (required if configuration is provided)") + response_type: Optional[Literal["text", "json_object"]] = Field(None, description="Response format") + configuration: Optional[Dict[str, Any]] = Field(None, description="Model-specific configuration") + + @model_validator(mode='before') + def validate_configuration(cls, data: Dict[str, Any]) -> Dict[str, Any]: + configuration = data.get("configuration") + model = data.get("model") + + # Skip validation if no configuration or no model + if configuration is None or model is None: + return data + + # Get model config (replace with your actual implementation) + model_config = model_config_document.get(model) + if not model_config: + raise ValueError(f"Model '{model}' not found in configuration document") + + config_schema = model_config.get("configuration", {}) + + for field_name, field_schema in config_schema.items(): + if field_name in configuration: + cls._validate_field( + field_name=field_name, + field_schema=field_schema, + user_value=configuration[field_name] + ) + + return data + + @classmethod + def _validate_field(cls, field_name: str, field_schema: Dict[str, Any], user_value: Any): + """Validate a single configuration field against its schema""" + field_type = field_schema.get("field") + + if field_type == "slider": + if not isinstance(user_value, (int, float)): + raise ValueError(f"{field_name} must be a number") + if "min" in field_schema and user_value < field_schema["min"]: + raise ValueError(f"{field_name} must be ≥ {field_schema['min']}") + if "max" in field_schema and user_value > field_schema["max"]: + raise ValueError(f"{field_name} must be ≤ {field_schema['max']}") + + elif field_type == "boolean": + if not isinstance(user_value, bool): + raise ValueError(f"{field_name} must be a boolean (True/False)") + + elif field_type == "dropdown": + options = field_schema.get("options", []) + if user_value not in options: + raise ValueError(f"{field_name} must be one of: {options}") + + elif field_type == "number": + if not isinstance(user_value, (int, float)): + raise ValueError(f"{field_name} must be a number") + + elif field_type == "select": + options = [opt["type"] for opt in field_schema.get("options", [])] + if user_value not in options: + raise ValueError(f"{field_name} must be one of: {options}") \ No newline at end of file