diff --git a/src/uipath_langchain/agent/tools/context_tool.py b/src/uipath_langchain/agent/tools/context_tool.py index 90012323..5f3f2b0f 100644 --- a/src/uipath_langchain/agent/tools/context_tool.py +++ b/src/uipath_langchain/agent/tools/context_tool.py @@ -1,7 +1,7 @@ """Context tool creation for semantic index retrieval.""" import uuid -from typing import Any +from typing import Any, Optional, Type from langchain_core.documents import Document from langchain_core.tools import StructuredTool @@ -26,6 +26,13 @@ from .utils import sanitize_tool_name +def is_static_query(resource: AgentContextResourceConfig) -> bool: + """Check if the resource configuration uses a static query variant.""" + if resource.settings.query is None or resource.settings.query.variant is None: + return False + return resource.settings.query.variant.lower() == "static" + + def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: tool_name = sanitize_tool_name(resource.name) retrieval_mode = resource.settings.retrieval_mode.lower() @@ -40,34 +47,58 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: def handle_semantic_search( tool_name: str, resource: AgentContextResourceConfig ) -> StructuredTool: + ensure_valid_fields(resource) + + # needed for type checking + assert resource.settings.query is not None + assert resource.settings.query.variant is not None + retriever = ContextGroundingRetriever( index_name=resource.index_name, folder_path=resource.folder_path, number_of_results=resource.settings.result_count, ) - class ContextInputSchemaModel(BaseModel): - query: str = Field( - ..., description="The query to search for in the knowledge base" - ) - class ContextOutputSchemaModel(BaseModel): documents: list[Document] = Field( ..., description="List of retrieved documents." ) - input_model = ContextInputSchemaModel output_model = ContextOutputSchemaModel - @mockable( - name=resource.name, - description=resource.description, - input_schema=input_model.model_json_schema(), - output_schema=output_model.model_json_schema(), - example_calls=[], # Examples cannot be provided for context. - ) - async def context_tool_fn(query: str) -> dict[str, Any]: - return {"documents": await retriever.ainvoke(query)} + if is_static_query(resource): + static_query_value = resource.settings.query.value + assert static_query_value is not None + input_model = None + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model, + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. + ) + async def context_tool_fn() -> dict[str, Any]: + return {"documents": await retriever.ainvoke(static_query_value)} + + else: + # Dynamic query - requires query parameter + class ContextInputSchemaModel(BaseModel): + query: str = Field( + ..., description="The query to search for in the knowledge base" + ) + + input_model = ContextInputSchemaModel + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. + ) + async def context_tool_fn(query: str) -> dict[str, Any]: + return {"documents": await retriever.ainvoke(query)} return StructuredToolWithOutputType( name=tool_name, @@ -82,36 +113,69 @@ def handle_deep_rag( tool_name: str, resource: AgentContextResourceConfig ) -> StructuredTool: ensure_valid_fields(resource) + # needed for type checking assert resource.settings.query is not None - assert resource.settings.query.value is not None + assert resource.settings.query.variant is not None index_name = resource.index_name - prompt = resource.settings.query.value if not resource.settings.citation_mode: raise ValueError("Citation mode is required for Deep RAG") citation_mode = CitationMode(resource.settings.citation_mode.value) - input_model = None output_model = DeepRagResponse - @mockable( - name=resource.name, - description=resource.description, - input_schema=input_model, - output_schema=output_model.model_json_schema(), - example_calls=[], # Examples cannot be provided for context. - ) - async def context_tool_fn() -> dict[str, Any]: - # TODO: add glob pattern support - return interrupt( - CreateDeepRag( - name=f"task-{uuid.uuid4()}", - index_name=index_name, - prompt=prompt, - citation_mode=citation_mode, + if is_static_query(resource): + # Static query - no input parameter needed + static_prompt = resource.settings.query.value + assert static_prompt is not None + input_model = None + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model, + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. + ) + async def context_tool_fn() -> dict[str, Any]: + # TODO: add glob pattern support + return interrupt( + CreateDeepRag( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=static_prompt, + citation_mode=citation_mode, + ) + ) + + else: + # Dynamic query - requires query parameter + class DeepRagInputSchemaModel(BaseModel): + query: str = Field( + ..., + description="Describe the task: what to research across documents, what to synthesize, and how to cite sources", ) + + input_model = DeepRagInputSchemaModel + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. ) + async def context_tool_fn(query: str) -> dict[str, Any]: + # TODO: add glob pattern support + return interrupt( + CreateDeepRag( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=query, + citation_mode=citation_mode, + ) + ) return StructuredToolWithOutputType( name=tool_name, @@ -129,11 +193,9 @@ def handle_batch_transform( # needed for type checking assert resource.settings.query is not None - assert resource.settings.query.value is not None + assert resource.settings.query.variant is not None index_name = resource.index_name - prompt = resource.settings.query.value - index_folder_path = resource.folder_path if not resource.settings.web_search_grounding: raise ValueError("Web search grounding field is required for Batch Transform") @@ -157,35 +219,82 @@ def handle_batch_transform( ) ) - class BatchTransformSchemaModel(BaseModel): - destination_path: str = Field( - ..., - description="The relative file path destination for the modified csv file", - ) - - input_model = BatchTransformSchemaModel output_model = BatchTransformResponse - @mockable( - name=resource.name, - description=resource.description, - input_schema=input_model.model_json_schema(), - output_schema=output_model.model_json_schema(), - example_calls=[], # Examples cannot be provided for context. - ) - async def context_tool_fn(destination_path: str) -> dict[str, Any]: - # TODO: storage_bucket_folder_path_prefix support - return interrupt( - CreateBatchTransform( - name=f"task-{uuid.uuid4()}", - index_name=index_name, - prompt=prompt, - destination_path=destination_path, - index_folder_path=index_folder_path, - enable_web_search_grounding=enable_web_search_grounding, - output_columns=batch_transform_output_columns, + input_model: Optional[Type[BaseModel]] + + if is_static_query(resource): + # Static query - only destination_path parameter needed + static_prompt = resource.settings.query.value + assert static_prompt is not None + + class StaticBatchTransformSchemaModel(BaseModel): + destination_path: str = Field( + default="output.csv", + description="The relative file path destination for the modified csv file", + ) + + input_model = StaticBatchTransformSchemaModel + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. + ) + async def context_tool_fn( + destination_path: str = "output.csv", + ) -> dict[str, Any]: + # TODO: storage_bucket_folder_path_prefix support + return interrupt( + CreateBatchTransform( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=static_prompt, + destination_path=destination_path, + index_folder_path=index_folder_path, + enable_web_search_grounding=enable_web_search_grounding, + output_columns=batch_transform_output_columns, + ) + ) + + else: + # Dynamic query - requires both query and destination_path parameters + class DynamicBatchTransformSchemaModel(BaseModel): + query: str = Field( + ..., + description="Describe the task for each row: what to analyze, what to extract, and how to populate the output columns", ) + destination_path: str = Field( + default="output.csv", + description="The relative file path destination for the modified csv file", + ) + + input_model = DynamicBatchTransformSchemaModel + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. ) + async def context_tool_fn( + query: str, destination_path: str = "output.csv" + ) -> dict[str, Any]: + # TODO: storage_bucket_folder_path_prefix support + return interrupt( + CreateBatchTransform( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=query, + destination_path=destination_path, + index_folder_path=index_folder_path, + enable_web_search_grounding=enable_web_search_grounding, + output_columns=batch_transform_output_columns, + ) + ) return StructuredToolWithOutputType( name=tool_name, @@ -199,5 +308,9 @@ async def context_tool_fn(destination_path: str) -> dict[str, Any]: def ensure_valid_fields(resource_config: AgentContextResourceConfig): if not resource_config.settings.query: raise ValueError("Query object is required") - if not resource_config.settings.query.value: - raise ValueError("Query prompt is required") + + if not resource_config.settings.query.variant: + raise ValueError("Query variant is required") + + if is_static_query(resource_config) and not resource_config.settings.query.value: + raise ValueError("Static query requires a query value to be set") diff --git a/tests/agent/tools/test_context_tool.py b/tests/agent/tools/test_context_tool.py index 5373bbd6..f2e18079 100644 --- a/tests/agent/tools/test_context_tool.py +++ b/tests/agent/tools/test_context_tool.py @@ -5,16 +5,22 @@ import pytest from langchain_core.documents import Document from uipath.agent.models.agent import ( + AgentContextOutputColumn, AgentContextQuerySetting, AgentContextResourceConfig, AgentContextRetrievalMode, AgentContextSettings, AgentContextValueSetting, ) -from uipath.platform.context_grounding import CitationMode, DeepRagResponse +from uipath.platform.context_grounding import ( + BatchTransformResponse, + CitationMode, + DeepRagResponse, +) from uipath_langchain.agent.tools.context_tool import ( create_context_tool, + handle_batch_transform, handle_deep_rag, handle_semantic_search, ) @@ -36,6 +42,7 @@ def _create_config( index_name="test-index", folder_path="/test/folder", query_value=None, + query_variant="static", citation_mode_value=None, retrieval_mode=AgentContextRetrievalMode.SEMANTIC, ): @@ -51,7 +58,7 @@ def _create_config( query=AgentContextQuerySetting( value=query_value, description="some description", - variant="variant", + variant=query_variant, ), citation_mode=citation_mode_value, ), @@ -83,12 +90,21 @@ def test_missing_query_object_raises_error(self, base_resource_config): with pytest.raises(ValueError, match="Query object is required"): handle_deep_rag("test_deep_rag", resource) - def test_missing_query_value_raises_error(self, base_resource_config): - """Test that missing query.value raises ValueError.""" - resource = base_resource_config() - resource.settings.query.value = None + def test_missing_static_query_value_raises_error(self, base_resource_config): + """Test that missing query.value for static variant raises ValueError.""" + resource = base_resource_config(query_variant="static", query_value=None) + + with pytest.raises( + ValueError, match="Static query requires a query value to be set" + ): + handle_deep_rag("test_deep_rag", resource) + + def test_missing_query_variant_raises_error(self, base_resource_config): + """Test that missing query.variant raises ValueError.""" + resource = base_resource_config(query_value="some query") + resource.settings.query.variant = None - with pytest.raises(ValueError, match="Query prompt is required"): + with pytest.raises(ValueError, match="Query variant is required"): handle_deep_rag("test_deep_rag", resource) def test_missing_citation_mode_raises_error(self, base_resource_config): @@ -197,6 +213,60 @@ async def test_unique_task_names_on_multiple_invocations( # Verify all have task- prefix assert all(name.startswith("task-") for name in task_names) + def test_dynamic_query_deep_rag_creation(self, base_resource_config): + """Test successful creation of Deep RAG tool with dynamic query.""" + resource = base_resource_config( + query_variant="dynamic", + query_value=None, + citation_mode_value=AgentContextValueSetting(value="Inline"), + ) + + result = handle_deep_rag("test_deep_rag", resource) + + assert isinstance(result, StructuredToolWithOutputType) + assert result.name == "test_deep_rag" + assert result.description == "Test Deep RAG tool" + assert result.args_schema is not None # Dynamic has input schema + assert result.output_type == DeepRagResponse + + def test_dynamic_query_deep_rag_has_query_parameter(self, base_resource_config): + """Test that dynamic Deep RAG tool has query parameter in schema.""" + resource = base_resource_config( + query_variant="dynamic", + query_value=None, + citation_mode_value=AgentContextValueSetting(value="Inline"), + ) + + result = handle_deep_rag("test_deep_rag", resource) + + # Check that the input schema has a query field + assert result.args_schema is not None + assert hasattr(result.args_schema, "model_json_schema") + schema = result.args_schema.model_json_schema() + assert "properties" in schema + assert "query" in schema["properties"] + assert schema["properties"]["query"]["type"] == "string" + + @pytest.mark.asyncio + async def test_dynamic_query_uses_provided_query(self, base_resource_config): + """Test that dynamic query variant uses the query parameter provided at runtime.""" + resource = base_resource_config( + query_variant="dynamic", + query_value=None, + citation_mode_value=AgentContextValueSetting(value="Inline"), + ) + tool = handle_deep_rag("test_tool", resource) + + with patch( + "uipath_langchain.agent.tools.context_tool.interrupt" + ) as mock_interrupt: + mock_interrupt.return_value = {"mocked": "response"} + assert tool.coroutine is not None + await tool.coroutine(query="runtime provided query") + + call_args = mock_interrupt.call_args[0][0] + assert call_args.prompt == "runtime provided query" + class TestCreateContextTool: """Test cases for create_context_tool function.""" @@ -213,6 +283,11 @@ def semantic_search_config(self): settings=AgentContextSettings( result_count=10, retrieval_mode=AgentContextRetrievalMode.SEMANTIC, + query=AgentContextQuerySetting( + value=None, + description="Query for semantic search", + variant="dynamic", + ), ), is_enabled=True, ) @@ -284,6 +359,11 @@ def semantic_config(self): settings=AgentContextSettings( result_count=5, retrieval_mode=AgentContextRetrievalMode.SEMANTIC, + query=AgentContextQuerySetting( + value=None, + description="Query for semantic search", + variant="dynamic", + ), ), is_enabled=True, ) @@ -335,3 +415,316 @@ async def test_semantic_search_returns_documents(self, semantic_config): assert "documents" in result assert len(result["documents"]) == 2 assert result["documents"][0].page_content == "Test content 1" + + def test_static_query_semantic_search_creation(self): + """Test successful creation of semantic search tool with static query.""" + resource = AgentContextResourceConfig( + name="semantic_tool", + description="Semantic search tool", + resource_type="context", + index_name="test-index", + folder_path="/test/folder", + settings=AgentContextSettings( + result_count=5, + retrieval_mode=AgentContextRetrievalMode.SEMANTIC, + query=AgentContextQuerySetting( + value="predefined static query", + description="Static query for semantic search", + variant="static", + ), + ), + is_enabled=True, + ) + + result = handle_semantic_search("semantic_tool", resource) + + assert isinstance(result, StructuredToolWithOutputType) + assert result.name == "semantic_tool" + assert result.description == "Semantic search tool" + assert result.args_schema is None # Static has no input schema + + @pytest.mark.asyncio + async def test_static_query_uses_predefined_query(self): + """Test that static query variant uses the predefined query value.""" + resource = AgentContextResourceConfig( + name="semantic_tool", + description="Semantic search tool", + resource_type="context", + index_name="test-index", + folder_path="/test/folder", + settings=AgentContextSettings( + result_count=5, + retrieval_mode=AgentContextRetrievalMode.SEMANTIC, + query=AgentContextQuerySetting( + value="predefined static query", + description="Static query for semantic search", + variant="static", + ), + ), + is_enabled=True, + ) + + mock_documents = [ + Document(page_content="Test content", metadata={"source": "doc1"}), + ] + + with patch( + "uipath_langchain.agent.tools.context_tool.ContextGroundingRetriever" + ) as mock_retriever_class: + mock_retriever = AsyncMock() + mock_retriever.ainvoke.return_value = mock_documents + mock_retriever_class.return_value = mock_retriever + + tool = handle_semantic_search("semantic_tool", resource) + assert tool.coroutine is not None + result = await tool.coroutine() + + # Verify the retriever was called with the static query value + mock_retriever.ainvoke.assert_called_once_with("predefined static query") + assert "documents" in result + assert len(result["documents"]) == 1 + + +class TestHandleBatchTransform: + """Test cases for handle_batch_transform function.""" + + @pytest.fixture + def batch_transform_config(self): + """Fixture for batch transform configuration with static query.""" + return AgentContextResourceConfig( + name="batch_transform_tool", + description="Batch transform tool", + resource_type="context", + index_name="test-index", + folder_path="/test/folder", + settings=AgentContextSettings( + result_count=5, + retrieval_mode=AgentContextRetrievalMode.BATCH_TRANSFORM, + query=AgentContextQuerySetting( + value="transform this data", + description="Static query for batch transform", + variant="static", + ), + web_search_grounding=AgentContextValueSetting(value="enabled"), + output_columns=[ + AgentContextOutputColumn( + name="output_col1", description="First output column" + ), + AgentContextOutputColumn( + name="output_col2", description="Second output column" + ), + ], + ), + is_enabled=True, + ) + + def test_static_query_batch_transform_creation(self, batch_transform_config): + """Test successful creation of batch transform tool with static query.""" + result = handle_batch_transform("batch_transform_tool", batch_transform_config) + + assert isinstance(result, StructuredToolWithOutputType) + assert result.name == "batch_transform_tool" + assert result.description == "Batch transform tool" + assert result.args_schema is not None # Has destination_path parameter + assert result.output_type == BatchTransformResponse + + def test_static_query_batch_transform_has_destination_path_only( + self, batch_transform_config + ): + """Test that static batch transform only has destination_path in schema.""" + result = handle_batch_transform("batch_transform_tool", batch_transform_config) + + assert result.args_schema is not None + assert hasattr(result.args_schema, "model_json_schema") + schema = result.args_schema.model_json_schema() + assert "properties" in schema + assert "destination_path" in schema["properties"] + assert "query" not in schema["properties"] # No query for static + + def test_dynamic_query_batch_transform_creation(self): + """Test successful creation of batch transform tool with dynamic query.""" + resource = AgentContextResourceConfig( + name="batch_transform_tool", + description="Batch transform tool", + resource_type="context", + index_name="test-index", + folder_path="/test/folder", + settings=AgentContextSettings( + result_count=5, + retrieval_mode=AgentContextRetrievalMode.BATCH_TRANSFORM, + query=AgentContextQuerySetting( + value=None, + description="Dynamic query for batch transform", + variant="dynamic", + ), + web_search_grounding=AgentContextValueSetting(value="enabled"), + output_columns=[ + AgentContextOutputColumn( + name="output_col1", description="First output column" + ), + ], + ), + is_enabled=True, + ) + + result = handle_batch_transform("batch_transform_tool", resource) + + assert isinstance(result, StructuredToolWithOutputType) + assert result.name == "batch_transform_tool" + assert result.args_schema is not None + assert result.output_type == BatchTransformResponse + + def test_dynamic_query_batch_transform_has_both_parameters(self): + """Test that dynamic batch transform has both query and destination_path.""" + resource = AgentContextResourceConfig( + name="batch_transform_tool", + description="Batch transform tool", + resource_type="context", + index_name="test-index", + folder_path="/test/folder", + settings=AgentContextSettings( + result_count=5, + retrieval_mode=AgentContextRetrievalMode.BATCH_TRANSFORM, + query=AgentContextQuerySetting( + value=None, + description="Dynamic query for batch transform", + variant="dynamic", + ), + web_search_grounding=AgentContextValueSetting(value="enabled"), + output_columns=[ + AgentContextOutputColumn( + name="output_col1", description="First output column" + ), + ], + ), + is_enabled=True, + ) + + result = handle_batch_transform("batch_transform_tool", resource) + + assert result.args_schema is not None + assert hasattr(result.args_schema, "model_json_schema") + schema = result.args_schema.model_json_schema() + assert "properties" in schema + assert "query" in schema["properties"] + assert "destination_path" in schema["properties"] + + @pytest.mark.asyncio + async def test_static_query_batch_transform_uses_predefined_query( + self, batch_transform_config + ): + """Test that static query variant uses the predefined query value.""" + tool = handle_batch_transform("batch_transform_tool", batch_transform_config) + + with patch( + "uipath_langchain.agent.tools.context_tool.interrupt" + ) as mock_interrupt: + mock_interrupt.return_value = {"mocked": "response"} + assert tool.coroutine is not None + await tool.coroutine(destination_path="/output/result.csv") + + call_args = mock_interrupt.call_args[0][0] + assert call_args.prompt == "transform this data" + assert call_args.destination_path == "/output/result.csv" + + @pytest.mark.asyncio + async def test_dynamic_query_batch_transform_uses_provided_query(self): + """Test that dynamic query variant uses the query parameter provided at runtime.""" + resource = AgentContextResourceConfig( + name="batch_transform_tool", + description="Batch transform tool", + resource_type="context", + index_name="test-index", + folder_path="/test/folder", + settings=AgentContextSettings( + result_count=5, + retrieval_mode=AgentContextRetrievalMode.BATCH_TRANSFORM, + query=AgentContextQuerySetting( + value=None, + description="Dynamic query for batch transform", + variant="dynamic", + ), + web_search_grounding=AgentContextValueSetting(value="enabled"), + output_columns=[ + AgentContextOutputColumn( + name="output_col1", description="First output column" + ), + ], + ), + is_enabled=True, + ) + + tool = handle_batch_transform("batch_transform_tool", resource) + + with patch( + "uipath_langchain.agent.tools.context_tool.interrupt" + ) as mock_interrupt: + mock_interrupt.return_value = {"mocked": "response"} + assert tool.coroutine is not None + await tool.coroutine( + query="runtime provided query", destination_path="/output/result.csv" + ) + + call_args = mock_interrupt.call_args[0][0] + assert call_args.prompt == "runtime provided query" + assert call_args.destination_path == "/output/result.csv" + + @pytest.mark.asyncio + async def test_static_query_batch_transform_uses_default_destination_path( + self, batch_transform_config + ): + """Test that static batch transform uses default destination_path when not provided.""" + tool = handle_batch_transform("batch_transform_tool", batch_transform_config) + + with patch( + "uipath_langchain.agent.tools.context_tool.interrupt" + ) as mock_interrupt: + mock_interrupt.return_value = {"mocked": "response"} + assert tool.coroutine is not None + # Call without providing destination_path + await tool.coroutine() + + call_args = mock_interrupt.call_args[0][0] + assert call_args.prompt == "transform this data" + assert call_args.destination_path == "output.csv" + + @pytest.mark.asyncio + async def test_dynamic_query_batch_transform_uses_default_destination_path(self): + """Test that dynamic batch transform uses default destination_path when not provided.""" + resource = AgentContextResourceConfig( + name="batch_transform_tool", + description="Batch transform tool", + resource_type="context", + index_name="test-index", + folder_path="/test/folder", + settings=AgentContextSettings( + result_count=5, + retrieval_mode=AgentContextRetrievalMode.BATCH_TRANSFORM, + query=AgentContextQuerySetting( + value=None, + description="Dynamic query for batch transform", + variant="dynamic", + ), + web_search_grounding=AgentContextValueSetting(value="enabled"), + output_columns=[ + AgentContextOutputColumn( + name="output_col1", description="First output column" + ), + ], + ), + is_enabled=True, + ) + + tool = handle_batch_transform("batch_transform_tool", resource) + + with patch( + "uipath_langchain.agent.tools.context_tool.interrupt" + ) as mock_interrupt: + mock_interrupt.return_value = {"mocked": "response"} + assert tool.coroutine is not None + # Call with only query, no destination_path + await tool.coroutine(query="runtime provided query") + + call_args = mock_interrupt.call_args[0][0] + assert call_args.prompt == "runtime provided query" + assert call_args.destination_path == "output.csv"