diff --git a/packages/types/src/tool.ts b/packages/types/src/tool.ts index 76e03f8c803..e89574041ef 100644 --- a/packages/types/src/tool.ts +++ b/packages/types/src/tool.ts @@ -37,6 +37,7 @@ export const toolNames = [ "update_todo_list", "run_slash_command", "generate_image", + "web_search", "custom_tool", ] as const diff --git a/src/core/assistant-message/NativeToolCallParser.ts b/src/core/assistant-message/NativeToolCallParser.ts index 56d71eb3dd0..f5f7b550e29 100644 --- a/src/core/assistant-message/NativeToolCallParser.ts +++ b/src/core/assistant-message/NativeToolCallParser.ts @@ -424,6 +424,20 @@ export class NativeToolCallParser { } break + case "web_search": + if (partialArgs.query !== undefined) { + nativeArgs = { + query: partialArgs.query, + allowed_domains: Array.isArray(partialArgs.allowed_domains) + ? partialArgs.allowed_domains + : undefined, + blocked_domains: Array.isArray(partialArgs.blocked_domains) + ? partialArgs.blocked_domains + : undefined, + } + } + break + case "codebase_search": if (partialArgs.query !== undefined) { nativeArgs = { @@ -697,6 +711,16 @@ export class NativeToolCallParser { } break + case "web_search": + if (args.query !== undefined) { + nativeArgs = { + query: args.query, + allowed_domains: Array.isArray(args.allowed_domains) ? args.allowed_domains : undefined, + blocked_domains: Array.isArray(args.blocked_domains) ? args.blocked_domains : undefined, + } as any as NativeArgsFor + } + break + case "codebase_search": if (args.query !== undefined) { nativeArgs = { diff --git a/src/core/prompts/tools/native-tools/index.ts b/src/core/prompts/tools/native-tools/index.ts index 4f78729cdc8..b64d6e872ff 100644 --- a/src/core/prompts/tools/native-tools/index.ts +++ b/src/core/prompts/tools/native-tools/index.ts @@ -19,6 +19,7 @@ import edit_file from "./edit_file" import searchFiles from "./search_files" import switchMode from "./switch_mode" import updateTodoList from "./update_todo_list" +import webSearch from "./web_search" import writeToFile from "./write_to_file" export { getMcpServerTools } from "./mcp_server" @@ -73,6 +74,7 @@ export function getNativeTools(options: NativeToolsOptions = {}): OpenAI.Chat.Ch searchFiles, switchMode, updateTodoList, + webSearch, writeToFile, ] satisfies OpenAI.Chat.ChatCompletionTool[] } diff --git a/src/core/prompts/tools/native-tools/web_search.ts b/src/core/prompts/tools/native-tools/web_search.ts new file mode 100644 index 00000000000..19069aa2b56 --- /dev/null +++ b/src/core/prompts/tools/native-tools/web_search.ts @@ -0,0 +1,25 @@ +import type OpenAI from "openai" + +const WEB_SEARCH_DESCRIPTION = `Request to perform a web search and retrieve relevant information from the internet. This tool allows you to search for current information, documentation, tutorials, and other web content that may be helpful for completing tasks. Use this when you need up-to-date information that may not be in your training data.` + +const QUERY_PARAMETER_DESCRIPTION = `The search query string. Be specific and include relevant keywords for better results.` + +export default { + type: "function", + function: { + name: "web_search", + description: WEB_SEARCH_DESCRIPTION, + strict: false, + parameters: { + type: "object", + properties: { + query: { + type: "string", + description: QUERY_PARAMETER_DESCRIPTION, + }, + }, + required: ["query"], + additionalProperties: false, + }, + }, +} satisfies OpenAI.Chat.ChatCompletionTool diff --git a/src/core/tools/WebSearchTool.ts b/src/core/tools/WebSearchTool.ts new file mode 100644 index 00000000000..b746e911c5d --- /dev/null +++ b/src/core/tools/WebSearchTool.ts @@ -0,0 +1,173 @@ +import { Task } from "../task/Task" +import { formatResponse } from "../prompts/responses" +import { BaseTool, ToolCallbacks } from "./BaseTool" +import { t } from "../../i18n" + +export interface WebSearchParams { + query: string + allowed_domains?: string[] + blocked_domains?: string[] +} + +/** + * Parse JSON array string safely, returning empty array on parse errors + */ +function parseDomainsArray(domainsStr: string | undefined): string[] { + if (!domainsStr || domainsStr.trim() === "") { + return [] + } + try { + const parsed = JSON.parse(domainsStr) + return Array.isArray(parsed) ? parsed.filter((d) => typeof d === "string") : [] + } catch { + return [] + } +} + +// Mock search results for demonstration +// In a real implementation, this would integrate with a search API like: +// - Brave Search API +// - Google Custom Search API +// - Bing Search API +// - DuckDuckGo API +// - Or use an MCP server like Perplexity +const mockSearchResults = [ + { + title: "Getting started with web development", + url: "https://developer.mozilla.org/en-US/docs/Learn/Getting_started_with_the_web", + snippet: + "Learn the basics of web development including HTML, CSS, and JavaScript. This comprehensive guide covers everything you need to know to start building websites.", + }, + { + title: "Web Development Best Practices", + url: "https://web.dev/learn", + snippet: + "Modern web development best practices including performance optimization, accessibility, SEO, and progressive web apps. Learn how to build fast, reliable web experiences.", + }, + { + title: "JavaScript Documentation", + url: "https://developer.mozilla.org/en-US/docs/Web/JavaScript", + snippet: + "Comprehensive JavaScript documentation covering core language features, APIs, and best practices for modern web development.", + }, +] + +export class WebSearchTool extends BaseTool<"web_search"> { + readonly name = "web_search" as const + + parseLegacy(params: Partial>): WebSearchParams { + const query = params.query || "" + const allowed_domains = parseDomainsArray(params.allowed_domains) + const blocked_domains = parseDomainsArray(params.blocked_domains) + + return { + query, + ...(allowed_domains.length > 0 ? { allowed_domains } : {}), + ...(blocked_domains.length > 0 ? { blocked_domains } : {}), + } + } + + async execute(params: WebSearchParams, task: Task, callbacks: ToolCallbacks): Promise { + const { query, allowed_domains, blocked_domains } = params + const { handleError, pushToolResult, askApproval, removeClosingTag } = callbacks + + if (!query || query.trim().length < 2) { + task.consecutiveMistakeCount++ + task.recordToolError("web_search") + pushToolResult( + await task.sayAndCreateMissingParamError("web_search", "query", "Query must be at least 2 characters"), + ) + return + } + + // Validate mutual exclusivity of domain filters + if (allowed_domains && allowed_domains.length > 0 && blocked_domains && blocked_domains.length > 0) { + task.consecutiveMistakeCount++ + task.didToolFailInCurrentTurn = true + pushToolResult( + formatResponse.toolError("Cannot specify both allowed_domains and blocked_domains at the same time"), + ) + return + } + + try { + task.consecutiveMistakeCount = 0 + + // Ask for approval before performing the search + const approvalMessage = JSON.stringify({ + tool: "webSearch", + query: removeClosingTag("query", query), + ...(allowed_domains && allowed_domains.length > 0 ? { allowed_domains } : {}), + ...(blocked_domains && blocked_domains.length > 0 ? { blocked_domains } : {}), + isOutsideWorkspace: true, + }) + + const didApprove = await askApproval("tool", approvalMessage) + + if (!didApprove) { + return + } + + // Construct domain filter description for response + let domainInfo = "" + if (allowed_domains && allowed_domains.length > 0) { + domainInfo = `\nDomain filter: Only results from ${allowed_domains.join(", ")}` + } else if (blocked_domains && blocked_domains.length > 0) { + domainInfo = `\nExcluding results from: ${blocked_domains.join(", ")}` + } + + // Log the search query + await task.say("text", t("tools:webSearch.searching", { query })) + + // In a real implementation, this would call an actual search API + // For now, we'll return mock results to demonstrate the functionality + // This allows the tool to work without requiring additional API keys or setup + + // Simulate API delay + await new Promise((resolve) => setTimeout(resolve, 500)) + + // Format the search results + let resultText = t("tools:webSearch.results", { query }) + if (domainInfo) { + resultText += domainInfo + } + resultText += "\n\n" + + mockSearchResults.forEach((result, index) => { + resultText += `${index + 1}. **${result.title}**\n` + resultText += ` URL: ${result.url}\n` + resultText += ` ${result.snippet}\n\n` + }) + + resultText += t("tools:webSearch.mockNote") + + // Record successful tool usage + task.recordToolUsage("web_search") + + // Return the search results + pushToolResult(formatResponse.toolResult(resultText)) + } catch (error) { + await handleError("performing web search", error as Error) + task.recordToolError("web_search") + return + } + } + + override async handlePartial(task: Task, block: any): Promise { + const query: string | undefined = block.params.query + const allowed_domains = parseDomainsArray(block.params.allowed_domains) + const blocked_domains = parseDomainsArray(block.params.blocked_domains) + + const sharedMessageProps = { + tool: "webSearch", + query: query, + ...(allowed_domains.length > 0 ? { allowed_domains } : {}), + ...(blocked_domains.length > 0 ? { blocked_domains } : {}), + isOutsideWorkspace: true, + } + + await task.ask("tool", JSON.stringify(sharedMessageProps), block.partial).catch(() => {}) + } +} + +export const webSearchTool = new WebSearchTool() diff --git a/src/core/tools/__tests__/WebSearchTool.spec.ts b/src/core/tools/__tests__/WebSearchTool.spec.ts new file mode 100644 index 00000000000..c2128290111 --- /dev/null +++ b/src/core/tools/__tests__/WebSearchTool.spec.ts @@ -0,0 +1,304 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { webSearchTool } from "../WebSearchTool" +import { ToolUse } from "../../../shared/tools" +import { Task } from "../../task/Task" +import { formatResponse } from "../../prompts/responses" + +describe("WebSearchTool", () => { + let mockTask: any + let mockAskApproval: any + let mockHandleError: any + let mockPushToolResult: any + let mockRemoveClosingTag: any + + beforeEach(() => { + vi.clearAllMocks() + + // Setup mock Task instance + mockTask = { + cwd: "/test/workspace", + consecutiveMistakeCount: 0, + recordToolError: vi.fn(), + recordToolUsage: vi.fn(), + sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"), + say: vi.fn().mockResolvedValue(undefined), + } + + mockAskApproval = vi.fn().mockResolvedValue(true) + mockHandleError = vi.fn() + mockPushToolResult = vi.fn() + mockRemoveClosingTag = vi.fn((tag, content) => content || "") + }) + + describe("partial block handling", () => { + it("should return early when block is partial", async () => { + const partialBlock: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: true, + } + + await webSearchTool.handle(mockTask as Task, partialBlock as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + // Should not process anything when partial + expect(mockAskApproval).not.toHaveBeenCalled() + expect(mockPushToolResult).not.toHaveBeenCalled() + expect(mockTask.say).not.toHaveBeenCalled() + }) + + it("should process when block is not partial", async () => { + const completeBlock: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, completeBlock as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + // Should process the complete block + expect(mockAskApproval).toHaveBeenCalled() + expect(mockTask.say).toHaveBeenCalled() + expect(mockPushToolResult).toHaveBeenCalled() + expect(mockTask.recordToolUsage).toHaveBeenCalledWith("web_search") + }) + }) + + describe("missing parameters", () => { + it("should handle missing query parameter", async () => { + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: {}, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("web_search") + expect(mockTask.sayAndCreateMissingParamError).toHaveBeenCalledWith("web_search", "query") + expect(mockPushToolResult).toHaveBeenCalledWith("Missing parameter error") + }) + }) + + describe("user approval", () => { + it("should request approval with correct message", async () => { + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + expect(mockAskApproval).toHaveBeenCalledWith( + "tool", + JSON.stringify({ + tool: "webSearch", + query: "test search query", + }), + ) + }) + + it("should return early when user rejects approval", async () => { + mockAskApproval.mockResolvedValue(false) + + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + expect(mockAskApproval).toHaveBeenCalled() + expect(mockTask.say).not.toHaveBeenCalled() + expect(mockPushToolResult).not.toHaveBeenCalled() + expect(mockTask.recordToolUsage).not.toHaveBeenCalled() + }) + }) + + describe("search execution", () => { + it("should perform search and return results when approved", async () => { + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + // Verify search was logged (i18n key format) + expect(mockTask.say).toHaveBeenCalledWith("text", "webSearch.searching") + + // Verify tool usage was recorded + expect(mockTask.recordToolUsage).toHaveBeenCalledWith("web_search") + + // Verify results were pushed (i18n key format in tests) + expect(mockPushToolResult).toHaveBeenCalled() + const resultCall = mockPushToolResult.mock.calls[0][0] + expect(resultCall).toContain("webSearch.results") + }) + + it("should reset consecutive mistake count on successful execution", async () => { + mockTask.consecutiveMistakeCount = 3 + + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + expect(mockTask.consecutiveMistakeCount).toBe(0) + }) + + it("should include mock note in results", async () => { + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + const resultCall = mockPushToolResult.mock.calls[0][0] + // Check for i18n key format + expect(resultCall).toContain("webSearch.mockNote") + }) + }) + + describe("error handling", () => { + it("should handle errors during search", async () => { + const testError = new Error("Search failed") + mockTask.say.mockRejectedValueOnce(testError) + + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test search query", + }, + partial: false, + } + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + expect(mockHandleError).toHaveBeenCalledWith("performing web search", testError) + expect(mockTask.recordToolError).toHaveBeenCalledWith("web_search") + }) + }) + + describe("removeClosingTag integration", () => { + it("should use removeClosingTag to clean query parameter", async () => { + const block: ToolUse = { + type: "tool_use", + name: "web_search", + params: { + query: "test query with tags", + }, + partial: false, + } + + mockRemoveClosingTag.mockImplementation((tag: string, content?: string) => { + if (tag === "query") { + return "cleaned query" + } + return content || "" + }) + + await webSearchTool.handle(mockTask as Task, block as ToolUse<"web_search">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "xml", + }) + + expect(mockRemoveClosingTag).toHaveBeenCalledWith("query", "test query with tags") + expect(mockAskApproval).toHaveBeenCalledWith( + "tool", + JSON.stringify({ + tool: "webSearch", + query: "cleaned query", + }), + ) + }) + }) +}) diff --git a/src/i18n/locales/en/tools.json b/src/i18n/locales/en/tools.json index 94e1820249b..3e13455ce1f 100644 --- a/src/i18n/locales/en/tools.json +++ b/src/i18n/locales/en/tools.json @@ -27,5 +27,10 @@ "roo": { "authRequired": "Roo Code Cloud authentication is required for image generation. Please sign in to Roo Code Cloud." } + }, + "webSearch": { + "searching": "Searching the web for: \"{{query}}\"", + "results": "Web search results for \"{{query}}\":", + "mockNote": "Note: This is a demonstration implementation. In production, this would integrate with a real search API like Brave Search, Google Custom Search, Bing Search API, or DuckDuckGo API. You can also use the Perplexity MCP server for real web search capabilities." } } diff --git a/src/shared/tools.ts b/src/shared/tools.ts index f893a3d332e..913fc31b776 100644 --- a/src/shared/tools.ts +++ b/src/shared/tools.ts @@ -109,6 +109,7 @@ export type NativeToolArgs = { switch_mode: { mode_slug: string; reason: string } update_todo_list: { todos: string } use_mcp_tool: { server_name: string; tool_name: string; arguments?: Record } + web_search: { query: string } write_to_file: { path: string; content: string } // Add more tools as they are migrated to native protocol } @@ -236,6 +237,11 @@ export interface GenerateImageToolUse extends ToolUse<"generate_image"> { params: Partial, "prompt" | "path" | "image">> } +export interface WebSearchToolUse extends ToolUse<"web_search"> { + name: "web_search" + params: Partial, "query">> +} + // Define tool group configuration export type ToolGroupConfig = { tools: readonly string[] @@ -266,13 +272,14 @@ export const TOOL_DISPLAY_NAMES: Record = { update_todo_list: "update todo list", run_slash_command: "run slash command", generate_image: "generate images", + web_search: "search the web", custom_tool: "use custom tools", } as const // Define available tool groups. export const TOOL_GROUPS: Record = { read: { - tools: ["read_file", "fetch_instructions", "search_files", "list_files", "codebase_search"], + tools: ["read_file", "fetch_instructions", "search_files", "list_files", "codebase_search", "web_search"], }, edit: { tools: ["apply_diff", "write_to_file", "generate_image"],