diff --git a/package.json b/package.json index c18f5642..94c28260 100644 --- a/package.json +++ b/package.json @@ -11,7 +11,6 @@ "test": "jest", "test:coverage": "jest --coverage", "test:watch": "jest --watch", - "generate-random-key": "tsx src/cli/generate-aes-key.ts", "staged": "npm run typecheck && npm run lint && jest", "typecheck": "tsc --noEmit --skipLibCheck", "format": "prettier --check .", diff --git a/src/app/(outerbase)/local-setting-dialog.tsx b/src/app/(outerbase)/local-setting-dialog.tsx index 8d896681..bb728824 100644 --- a/src/app/(outerbase)/local-setting-dialog.tsx +++ b/src/app/(outerbase)/local-setting-dialog.tsx @@ -26,8 +26,6 @@ export const localSettingDialog = createDialog(({ close }) => { }, []); const onSaveClicked = useCallback(() => { - if (!token) return; - updateAgentFromLocalStorage({ provider: "openai", model: "gpt-4o-mini", diff --git a/src/app/(theme)/client/s/[[...driver]]/page-client.tsx b/src/app/(theme)/client/s/[[...driver]]/page-client.tsx index a1f6be46..56dec655 100644 --- a/src/app/(theme)/client/s/[[...driver]]/page-client.tsx +++ b/src/app/(theme)/client/s/[[...driver]]/page-client.tsx @@ -13,7 +13,7 @@ import { } from "@/core/standard-extension"; import { createLocalDriver } from "@/drivers/helpers"; import IndexdbSavedDoc from "@/drivers/saved-doc/indexdb-saved-doc"; -import { useAgentFromLocalStorage } from "@/lib/ai-agent-storage"; +import { useAvailableAIAgents } from "@/lib/ai-agent-storage"; import { useRouter, useSearchParams } from "next/navigation"; import { useCallback, useEffect, useMemo } from "react"; @@ -57,7 +57,7 @@ export default function ClientPageBody() { return new StudioExtensionManager(createStandardExtensions()); }, [driver]); - const agentDriver = useAgentFromLocalStorage(driver); + const agentDriver = useAvailableAIAgents(driver); const docDriver = useMemo(() => { if (conn) { diff --git a/src/app/(theme)/connect/saved-connection-storage.ts b/src/app/(theme)/connect/saved-connection-storage.ts index 4aeee182..c45960bb 100644 --- a/src/app/(theme)/connect/saved-connection-storage.ts +++ b/src/app/(theme)/connect/saved-connection-storage.ts @@ -1,5 +1,3 @@ -import { ApiUser } from "@/lib/api/api-database-response"; - export type SupportedDriver = | "turso" | "rqlite" @@ -27,9 +25,6 @@ export interface SavedConnectionItem { name: string; description?: string; label?: SavedConnectionLabel; - shared?: { - sharedBy: ApiUser; - }; } export interface SavedConnectionItemConfigConfig { diff --git a/src/app/(theme)/embed/[driver]/page-client.tsx b/src/app/(theme)/embed/[driver]/page-client.tsx index cfbe9832..88fd58b7 100644 --- a/src/app/(theme)/embed/[driver]/page-client.tsx +++ b/src/app/(theme)/embed/[driver]/page-client.tsx @@ -14,7 +14,7 @@ import { import ElectronSavedDocs from "@/drivers/saved-doc/electron-saved-doc"; import DoltExtension from "@/extensions/dolt"; import LocalSettingSidebar from "@/extensions/local-setting-sidebar"; -import { useAgentFromLocalStorage } from "@/lib/ai-agent-storage"; +import { useAvailableAIAgents } from "@/lib/ai-agent-storage"; import { useSearchParams } from "next/navigation"; import { useEffect, useMemo } from "react"; @@ -39,7 +39,7 @@ export default function EmbedPageClient({ return new StudioExtensionManager(createEmbedExtensions(driverName)); }, [driverName]); - const agentDriver = useAgentFromLocalStorage(driver); + const agentDriver = useAvailableAIAgents(driver); useEffect(() => { return driver.listen(); diff --git a/src/app/(theme)/playground/client/page-client.tsx b/src/app/(theme)/playground/client/page-client.tsx index e91d366c..6e2c40c2 100644 --- a/src/app/(theme)/playground/client/page-client.tsx +++ b/src/app/(theme)/playground/client/page-client.tsx @@ -10,7 +10,7 @@ import { StudioExtensionManager } from "@/core/extension-manager"; import { createSQLiteExtensions } from "@/core/standard-extension"; import SqljsDriver from "@/drivers/sqljs-driver"; import { localDb } from "@/indexdb"; -import { useAgentFromLocalStorage } from "@/lib/ai-agent-storage"; +import { useAvailableAIAgents } from "@/lib/ai-agent-storage"; import downloadFileFromUrl from "@/lib/download-file"; import { saveAs } from "file-saver"; import { @@ -44,7 +44,7 @@ export default function PlaygroundEditorBody({ const [handler, setHandler] = useState(); const [fileName, setFilename] = useState(""); - const agentDriver = useAgentFromLocalStorage(driver); + const agentDriver = useAvailableAIAgents(driver); /** * Initialize the SQL.js library. diff --git a/src/lib/api/insert-tracking-record.ts b/src/app/api/events/insert-tracking-record.ts similarity index 94% rename from src/lib/api/insert-tracking-record.ts rename to src/app/api/events/insert-tracking-record.ts index 9beea42f..65b3c1fa 100644 --- a/src/lib/api/insert-tracking-record.ts +++ b/src/app/api/events/insert-tracking-record.ts @@ -1,8 +1,8 @@ "use server"; import StarbaseDriver from "@/drivers/starbase-driver"; -import { type TrackEventItem } from "../tracking"; import { env } from "@/env"; +import { type TrackEventItem } from "../../../lib/tracking"; export async function insertTrackingRecord( deviceId: string, diff --git a/src/app/api/events/route.ts b/src/app/api/events/route.ts index 55e444f7..6a201c31 100644 --- a/src/app/api/events/route.ts +++ b/src/app/api/events/route.ts @@ -5,9 +5,9 @@ // All recorded data will be stored in the Starbase Database. import { headers } from "next/headers"; -import zod from "zod"; import { after, NextRequest, NextResponse } from "next/server"; -import { insertTrackingRecord } from "@/lib/api/insert-tracking-record"; +import zod from "zod"; +import { insertTrackingRecord } from "./insert-tracking-record"; const eventBodySchema = zod.object({ events: zod diff --git a/src/app/storybook/editor/page.tsx b/src/app/storybook/editor/page.tsx index cae91a74..0dc9ee16 100644 --- a/src/app/storybook/editor/page.tsx +++ b/src/app/storybook/editor/page.tsx @@ -28,7 +28,6 @@ END; dialect="sqlite" value={value} onChange={setValue} - enablePrompt onPrompt={async (prompt, selected) => { const selectedText = selected?.text; diff --git a/src/cli/generate-aes-key.ts b/src/cli/generate-aes-key.ts deleted file mode 100644 index 1bbf5353..00000000 --- a/src/cli/generate-aes-key.ts +++ /dev/null @@ -1,3 +0,0 @@ -import crypto from "crypto"; - -console.log("Generate Random Key:", crypto.randomBytes(32).toString("base64")); diff --git a/src/components/editor/prompt-plugin.tsx b/src/components/editor/prompt-plugin.tsx index a3bbc338..d6079b18 100644 --- a/src/components/editor/prompt-plugin.tsx +++ b/src/components/editor/prompt-plugin.tsx @@ -1,3 +1,4 @@ +import AgentDriverList from "@/drivers/agent/list"; import { unifiedMergeView } from "@codemirror/merge"; import { Compartment, @@ -38,6 +39,7 @@ export interface PromptSelectedFragment { startLineNumber: number; endLineNumber: number; sessionId: string; + selectedModel?: string; } export type PromptCallback = ( @@ -147,7 +149,7 @@ class PromptWidget extends WidgetType { // If the result of a query arrives and the counter has changed, the result is ignored. let queryCounter = 1; - const onGenerate = async (promptText: string) => { + const onGenerate = async (promptText: string, selectedModel?: string) => { try { const expectedQueryCounter = queryCounter; @@ -159,6 +161,7 @@ class PromptWidget extends WidgetType { endLineNumber: view.state.doc.lineAt( startPosition + suggestedText.length ).number, + selectedModel, }); // This prevent when we close the widget before @@ -240,6 +243,7 @@ class PromptWidget extends WidgetType { onSubmit={onGenerate} onReject={onReject} onCancel={onCancel} + agentDriver={plugin.agents} /> ); } @@ -325,6 +329,7 @@ export class CodeMirrorPromptPlugin { public isActive = false; activeWidget?: PromptWidget; promptCallback?: PromptCallback; + public agents?: AgentDriverList; /** * This is for locking when prompt is open. This is to prevent diff --git a/src/components/editor/prompt-widget.tsx b/src/components/editor/prompt-widget.tsx index d4baa0c0..1d3be257 100644 --- a/src/components/editor/prompt-widget.tsx +++ b/src/components/editor/prompt-widget.tsx @@ -1,16 +1,27 @@ "use client"; +import AgentDriverList from "@/drivers/agent/list"; import { cn } from "@/lib/utils"; -import { X } from "@phosphor-icons/react"; -import { useCallback, useRef, useState } from "react"; +import { Check, X } from "@phosphor-icons/react"; +import { useCallback, useMemo, useRef, useState } from "react"; +import { CloudflareIcon } from "../icons/outerbase-icon"; import { Button } from "../orbit/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuGroup, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuTrigger, +} from "../ui/dropdown-menu"; interface CodeMirrorPromptWidgetProps { + agentDriver?: AgentDriverList; onClose?: () => void; onCancel?: () => void; onAccept?: () => void; onReject?: () => void; - onSubmit?: (query: string) => Promise; + onSubmit?: (query: string, selectedModel?: string) => Promise; } export function CodeMirrorPromptWidget({ @@ -19,6 +30,7 @@ export function CodeMirrorPromptWidget({ onSubmit, onAccept, onReject, + agentDriver, }: CodeMirrorPromptWidgetProps) { const textareaClassName = "absolute left-0 right-0 resize-none p-1 p-2 outline-none"; @@ -26,6 +38,14 @@ export function CodeMirrorPromptWidget({ const fakeTextareaRef = useRef(null); const textareaRef = useRef(null); + const agentList = useMemo(() => { + if (!agentDriver) return []; + return agentDriver.list(); + }, [agentDriver]); + const [selectedAgent, setSelectedAgent] = useState(() => + agentDriver?.getDefaultModelName() + ); + const [previousPrompt, setPreviousPrompt] = useState(""); const [error, setError] = useState(""); const [loading, setLoading] = useState(false); @@ -58,7 +78,7 @@ export function CodeMirrorPromptWidget({ setError(""); cancelTriggered.current = false; - onSubmit(prompt) + onSubmit(prompt, selectedAgent) .then(() => { if (!cancelTriggered.current) { setPreviousPrompt(prompt); @@ -77,7 +97,7 @@ export function CodeMirrorPromptWidget({ setLoading(false); }); } - }, [onSubmit, prompt]); + }, [onSubmit, prompt, selectedAgent]); const triggerReject = useCallback(() => { if (onReject) onReject(); @@ -162,7 +182,11 @@ export function CodeMirrorPromptWidget({ - {error &&
{error}
} + {error && ( +
+ {error.split("\n")[0]} +
+ )}
{(showSubmitButton || showSubmitEditButton) && ( @@ -194,6 +218,54 @@ export function CodeMirrorPromptWidget({ )} +
+ + {agentList.length > 0 && ( + + + + + + + {agentList.map((group) => ( + + {group.title} + {group.agents.map((agent) => ( + { + setSelectedAgent(agent.name); + if (agentDriver) { + agentDriver.setDefaultModelName(agent.name); + } + }} + > + {agent.name === selectedAgent ? ( + + ) : null} + {agent.name} + {agent.free ? ( +
+ + free tier + +
+ ) : null} +
+ ))} +
+ ))} +
+
+ )} + {!previousPrompt && !loading ? ( Esc to close diff --git a/src/components/gui/query-explanation-diagram/index.tsx b/src/components/gui/query-explanation-diagram/index.tsx index 97491968..c7dc1d8a 100644 --- a/src/components/gui/query-explanation-diagram/index.tsx +++ b/src/components/gui/query-explanation-diagram/index.tsx @@ -17,12 +17,14 @@ import { NestedLoop } from "./node-type/nested-loop"; import { TableBlock } from "./node-type/table-block"; import { OperationBlock } from "./node-type/operation-block"; import { UnionBlock } from "./node-type/union-block"; +import { useTheme } from "next-themes"; interface LayoutFlowProps { items: ExplanationMysql; } function QueryExplanationFlow(props: LayoutFlowProps) { + const { forcedTheme, resolvedTheme } = useTheme(); const [loading, setLoading] = useState(true); const [nodes, setNodes, onNodesChange] = useNodesState([]); const [edges, setEdges, onEdgesChange] = useEdgesState([]); @@ -56,8 +58,11 @@ function QueryExplanationFlow(props: LayoutFlowProps) { } }, [props, loading, setEdges, setNodes]); + const appTheme = (forcedTheme ?? resolvedTheme) as "dark" | "light"; + return ( ( onCursorChange, readOnly, fontSize, + agents, onFontSizeChanged, variableList, highlightVariable, - - enablePrompt, onPrompt, }: SqlEditorProps, ref @@ -94,14 +94,15 @@ const SqlEditor = forwardRef( }, [schema]); const promptPlugin = useMemo(() => { - return enablePrompt ? new CodeMirrorPromptPlugin() : null; - }, [enablePrompt]); + return agents ? new CodeMirrorPromptPlugin() : null; + }, [agents]); useEffect(() => { if (promptPlugin && onPrompt) { promptPlugin.handleSuggestion(onPrompt); + promptPlugin.agents = agents; } - }, [promptPlugin, onPrompt]); + }, [promptPlugin, onPrompt, agents]); const keyExtensions = useMemo(() => { return keymap.of([ diff --git a/src/components/gui/studio.tsx b/src/components/gui/studio.tsx index b181395a..84d3bf9b 100644 --- a/src/components/gui/studio.tsx +++ b/src/components/gui/studio.tsx @@ -4,9 +4,8 @@ import { ConfigProvider } from "@/context/config-provider"; import { DriverProvider } from "@/context/driver-provider"; import { StudioExtensionManager } from "@/core/extension-manager"; import { BeforeQueryPipeline } from "@/core/query-pipeline"; -import { AgentBaseDriver } from "@/drivers/agent/base"; +import AgentDriverList from "@/drivers/agent/list"; import type { BaseDriver } from "@/drivers/base-driver"; -import { CollaborationBaseDriver } from "@/drivers/collaboration-driver-base"; import { SavedDocDriver } from "@/drivers/saved-doc/saved-doc-driver"; import { useEffect, useMemo, useRef } from "react"; import { CommonDialogProvider } from "../common-dialog"; @@ -15,8 +14,7 @@ import { FullEditorProvider } from "./providers/full-editor-provider"; interface StudioProps { driver: BaseDriver; extensions?: StudioExtensionManager; - collaboration?: CollaborationBaseDriver; - agentDriver?: AgentBaseDriver; + agentDriver?: AgentDriverList; docDriver?: SavedDocDriver; name: string; color: string; @@ -27,7 +25,6 @@ interface StudioProps { export function Studio({ driver, - collaboration, name, color, onBack, @@ -107,11 +104,7 @@ export function Studio({ ]); return ( - + diff --git a/src/components/gui/tabs/query-tab.tsx b/src/components/gui/tabs/query-tab.tsx index 6f2ecd61..02655982 100644 --- a/src/components/gui/tabs/query-tab.tsx +++ b/src/components/gui/tabs/query-tab.tsx @@ -322,22 +322,21 @@ export default function QueryWindow({ [] ); - const promptConversationIds = useRef>({}); const onPrompt = useCallback( async (promptQuery: string, option: PromptSelectedFragment) => { if (!agentDriver) return ""; - const agentResponse = await agentDriver.promptInline( + const agentResponse = await agentDriver.run( + option.selectedModel ?? "gemma-7b-it", promptQuery, - promptConversationIds.current[option.sessionId], + option.sessionId, { selected: option?.text ?? "", schema: schema, } ); - promptConversationIds.current[option.sessionId] = agentResponse.id; - return agentResponse.result; + return agentResponse; }, [agentDriver, schema] ); @@ -419,8 +418,8 @@ export default function QueryWindow({
x.position.y)) ?? 0; + : (Math.min(...layoutRelationship.nodes.map((x) => x.position.y)) ?? 0); // Calculate estimate area of the nodes without relationship const area = @@ -279,6 +280,7 @@ function mapSchema( } function LayoutFlow() { + const { forcedTheme, resolvedTheme } = useTheme(); const [nodes, setNodes, onNodesChange] = useNodesState([]); const [edges, setEdges, onEdgesChange] = useEdgesState([]); const { schema: initialSchema, currentSchemaName, refresh } = useSchema(); @@ -297,6 +299,8 @@ function LayoutFlow() { databaseSchema: DatabaseSchemaNode, }; + const appTheme = (forcedTheme ?? resolvedTheme) as "dark" | "light"; + return (
@@ -368,6 +372,7 @@ function LayoutFlow() { {selectedSchema && (
- + diff --git a/src/context/config-provider.tsx b/src/context/config-provider.tsx index 575dc4dd..89fa96fd 100644 --- a/src/context/config-provider.tsx +++ b/src/context/config-provider.tsx @@ -1,5 +1,5 @@ import { StudioExtensionManager } from "@/core/extension-manager"; -import { AgentBaseDriver } from "@/drivers/agent/base"; +import AgentDriverList from "@/drivers/agent/list"; import { noop } from "lodash"; import type { PropsWithChildren } from "react"; import { createContext, useContext } from "react"; @@ -10,7 +10,7 @@ interface ConfigContextProps { onBack?: () => void; extensions: StudioExtensionManager; containerClassName?: string; - agentDriver?: AgentBaseDriver; + agentDriver?: AgentDriverList; } const ConfigContext = createContext({ diff --git a/src/context/driver-provider.tsx b/src/context/driver-provider.tsx index b0e09448..9e380019 100644 --- a/src/context/driver-provider.tsx +++ b/src/context/driver-provider.tsx @@ -1,11 +1,9 @@ import type { BaseDriver } from "@/drivers/base-driver"; -import { CollaborationBaseDriver } from "@/drivers/collaboration-driver-base"; import { SavedDocDriver } from "@/drivers/saved-doc/saved-doc-driver"; import { type PropsWithChildren, createContext, useContext } from "react"; const DriverContext = createContext<{ databaseDriver: BaseDriver; - collaborationDriver?: CollaborationBaseDriver; docDriver?: SavedDocDriver; }>({ databaseDriver: {} as unknown as BaseDriver, @@ -18,17 +16,13 @@ export function useDatabaseDriver() { export function DriverProvider({ children, driver, - collaborationDriver, docDriver, }: PropsWithChildren<{ driver: BaseDriver; - collaborationDriver?: CollaborationBaseDriver; docDriver?: SavedDocDriver; }>) { return ( - + {children} ); diff --git a/src/drivers/agent/base.ts b/src/drivers/agent/base.ts index 379d3b89..052a4065 100644 --- a/src/drivers/agent/base.ts +++ b/src/drivers/agent/base.ts @@ -18,9 +18,9 @@ export abstract class AgentBaseDriver { * @param previousId Previous message id. If not provided, it is a new conversation * @param option */ - abstract promptInline( + abstract run( message: string, previousId: string | undefined, option: AgentPromptOption - ): Promise; + ): Promise; } diff --git a/src/drivers/agent/chatgpt.ts b/src/drivers/agent/chatgpt.ts index af9c8118..475118cf 100644 --- a/src/drivers/agent/chatgpt.ts +++ b/src/drivers/agent/chatgpt.ts @@ -1,155 +1,18 @@ -import { - BaseDriver, - DatabaseSchemas, - DatabaseTableSchema, -} from "../base-driver"; -import { - AgentBaseDriver, - AgentPromptOption, - AgentPromptResponse, -} from "./base"; - -interface ChatHistory { - id: string; - createdAt: number; - messages: { role: string; content: string }[]; -} - +import { BaseDriver } from "../base-driver"; +import CommonAgentDriverImplementation, { CommonAgentMessage } from "./common"; interface ChatGPTResponse { choices: { message: { role: string; content: string } }[]; } -export class ChatGPTDriver implements AgentBaseDriver { - protected history: Record = {}; - +export class ChatGPTDriver extends CommonAgentDriverImplementation { constructor( protected driver: BaseDriver, protected token: string - ) {} - - protected convertTableToContent( - schemaName: string | undefined, - table: DatabaseTableSchema - ): string { - const columns = table.columns - .map((column) => { - return `${this.driver.escapeId(column.name)} ${column.type}`; - }) - .join(",\n"); - - const fullTableName = schemaName - ? `${this.driver.escapeId(schemaName)}.${this.driver.escapeId(table.tableName ?? "")}` - : this.driver.escapeId(table.tableName ?? ""); - - const primaryKeyPart = - table.pk.length > 0 - ? `, PRIMARY KEY (${table.pk.map(this.driver.escapeId).join(", ")})` - : ""; - - const foreignKeyPart: string[] = []; - for (const column of table.columns) { - if (column.constraint?.foreignKey) { - foreignKeyPart.push( - [ - "FOREIGN KEY", - column.name, - "REFERENCES", - column.constraint.foreignKey.foreignTableName ?? "", - "(", - (column.constraint?.foreignKey?.foreignColumns ?? [])[0] ?? "", - ")", - ].join(" ") - ); - } - } - - for (const constraint of table.constraints ?? []) { - if (constraint.foreignKey) { - foreignKeyPart.push( - [ - "FOREIGN KEY", - `(${(constraint.foreignKey.columns ?? []).join(", ")})`, - "REFERENCES", - constraint.foreignKey.foreignTableName ?? "", - `(${(constraint.foreignKey.foreignColumns ?? []).join(", ")})`, - ].join(" ") - ); - } - } - - return `CREATE TABLE ${fullTableName} (\n${columns}\n ${primaryKeyPart});`; + ) { + super(driver); } - protected convertSchemaToContent(schemas: DatabaseSchemas): string { - const schemaParts: string[] = []; - const defaultSchema = this.driver.getFlags().defaultSchema; - - for (const [schemaName, schema] of Object.entries(schemas)) { - for (const table of schema) { - if (!table.tableSchema) continue; - if (!["table", "view"].includes(table.type)) continue; - - schemaParts.push( - this.convertTableToContent( - defaultSchema.toLowerCase() === schemaName.toLowerCase() - ? "" - : schemaName, - table.tableSchema - ) - ); - } - } - - return schemaParts.join("\n\n"); - } - - async promptInline( - message: string, - previousId: string | undefined, - option: AgentPromptOption - ): Promise { - const session = this.history[previousId ?? ""] ?? { - id: crypto.randomUUID(), - createdAt: Date.now(), - messages: [], - }; - - if (session.messages.length === 0) { - session.messages.push({ - role: "system", - content: "You are an SQL expert. Only return SQL code.", - }); - - session.messages.push({ - role: "user", - content: - "Here is " + - this.driver.getFlags().dialect + - " my database schema:\n\n", - }); - - if (option.schema) { - session.messages.push({ - role: "user", - content: - "```sql\n" + this.convertSchemaToContent(option.schema) + "```", - }); - } - - if (option.selected) { - session.messages.push({ - role: "user", - content: - "This is my selected query ```sql\n" + option.selected + "```", - }); - } - } - - session.messages.push({ - role: "user", - content: message, - }); - + async query(messages: CommonAgentMessage[]): Promise { const response = await fetch("https://api.openai.com/v1/chat/completions", { method: "POST", headers: { @@ -157,29 +20,13 @@ export class ChatGPTDriver implements AgentBaseDriver { "Content-Type": "application/json", }, body: JSON.stringify({ - model: "gpt-4o-mini-2024-07-18", + model: "gpt-4o-mini", temperature: 0, - messages: session.messages, + messages: messages, }), }); const jsonResponse = (await response.json()) as ChatGPTResponse; - const suggestedQuery = jsonResponse.choices[0].message.content; - - // Striped the SQL code from the response - const sqlCode = suggestedQuery.replace(/```sql\n/g, "").replace(/```/g, ""); - - // Save the chat history - session.messages.push({ - role: "assistant", - content: suggestedQuery, - }); - - this.history[session.id] = session; - - return { - id: session.id, - result: sqlCode, - }; + return jsonResponse.choices[0].message.content; } } diff --git a/src/drivers/agent/cloudflare.ts b/src/drivers/agent/cloudflare.ts new file mode 100644 index 00000000..8821677f --- /dev/null +++ b/src/drivers/agent/cloudflare.ts @@ -0,0 +1,43 @@ +import { format } from "sql-formatter"; +import { BaseDriver } from "../base-driver"; +import CommonAgentDriverImplementation, { CommonAgentMessage } from "./common"; + +export default class CloudflareAgentDriver extends CommonAgentDriverImplementation { + constructor( + protected driver: BaseDriver, + protected model: + | "@cf/defog/sqlcoder-7b-2" + | "@cf/meta/llama-3.3-70b-instruct-fp8-fast" + ) { + super(driver); + } + + async query(messages: CommonAgentMessage[]): Promise { + const response = await fetch( + "https://studio-ai-agent.outerbase.workers.dev/chat/complete", + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: this.model, + messages: messages, + }), + } + ); + + const jsonResponse = (await response.json()) as { response: string }; + return jsonResponse.response; + } + + processResult(result: string): string { + if (this.model === "@cf/defog/sqlcoder-7b-2") { + return format(result, { + language: "sqlite", + }); + } + + return super.processResult(result); + } +} diff --git a/src/drivers/agent/common.ts b/src/drivers/agent/common.ts new file mode 100644 index 00000000..f97b0f86 --- /dev/null +++ b/src/drivers/agent/common.ts @@ -0,0 +1,185 @@ +import { + BaseDriver, + DatabaseSchemas, + DatabaseTableSchema, +} from "../base-driver"; +import { AgentBaseDriver, AgentPromptOption } from "./base"; + +export interface ChatHistory { + id: string; + createdAt: number; + messages: { role: string; content: string }[]; +} + +export interface CommonAgentMessage { + role: string; + content: string; +} + +export default abstract class CommonAgentDriverImplementation extends AgentBaseDriver { + protected history: Record = {}; + + abstract query(messages: CommonAgentMessage[]): Promise; + + getSystemContent(option: AgentPromptOption): string { + if (option.selected) { + return `You are an SQL expert. User is using ${this.driver.getFlags().dialect}. You are given a user selected query and you will improve it. Only return SQL code`; + } + + return `You are an SQL expert. User is using ${this.driver.getFlags().dialect}.Only return SQL code`; + } + + getSchemaContent(option: AgentPromptOption) { + const parts = []; + + if (option.schema) { + parts.push( + "Here is " + this.driver.getFlags().dialect + " my database schema:\n\n" + ); + + parts.push( + "```sql\n" + this.convertSchemaToDDLContent(option.schema) + "```" + ); + } + + return parts.join("\n"); + } + + processResult(result: string): string { + // Find the code block and extract it + const codeBlock = result.match(/```sql\n([\s\S]*?)```/); + if (codeBlock) { + return codeBlock[1]; + } + + throw new Error("We cannot generate good response"); + } + + async run( + message: string, + previousId: string | undefined, + option: AgentPromptOption + ): Promise { + const session = this.history[previousId ?? ""] ?? { + id: previousId || crypto.randomUUID(), + createdAt: Date.now(), + messages: [], + }; + + if (session.messages.length === 0) { + session.messages.push({ + role: "system", + content: this.getSystemContent(option), + }); + + session.messages.push({ + role: "user", + content: this.getSchemaContent(option), + }); + + if (option.selected) { + session.messages.push({ + role: "user", + content: + "This is my selected query ```sql\n" + option.selected + "```", + }); + } + } + + session.messages.push({ + role: "user", + content: message, + }); + + const result = await this.query(session.messages); + + // Save the chat history + session.messages.push({ + role: "assistant", + content: result, + }); + + this.history[session.id] = session; + return this.processResult(result); + } + + constructor(protected driver: BaseDriver) { + super(); + } + + protected convertTableToDDLContent( + schemaName: string | undefined, + table: DatabaseTableSchema + ): string { + const columns = table.columns + .map((column) => { + return `${this.driver.escapeId(column.name)} ${column.type}`; + }) + .join(",\n"); + + const fullTableName = schemaName + ? `${this.driver.escapeId(schemaName)}.${this.driver.escapeId(table.tableName ?? "")}` + : this.driver.escapeId(table.tableName ?? ""); + + const primaryKeyPart = + table.pk.length > 0 + ? `, PRIMARY KEY (${table.pk.map(this.driver.escapeId).join(", ")})` + : ""; + + const foreignKeyPart: string[] = []; + for (const column of table.columns) { + if (column.constraint?.foreignKey) { + foreignKeyPart.push( + [ + "FOREIGN KEY", + column.name, + "REFERENCES", + column.constraint.foreignKey.foreignTableName ?? "", + "(", + (column.constraint?.foreignKey?.foreignColumns ?? [])[0] ?? "", + ")", + ].join(" ") + ); + } + } + + for (const constraint of table.constraints ?? []) { + if (constraint.foreignKey) { + foreignKeyPart.push( + [ + "FOREIGN KEY", + `(${(constraint.foreignKey.columns ?? []).join(", ")})`, + "REFERENCES", + constraint.foreignKey.foreignTableName ?? "", + `(${(constraint.foreignKey.foreignColumns ?? []).join(", ")})`, + ].join(" ") + ); + } + } + + return `CREATE TABLE ${fullTableName} (\n${columns}\n ${primaryKeyPart});`; + } + + protected convertSchemaToDDLContent(schemas: DatabaseSchemas): string { + const schemaParts: string[] = []; + const defaultSchema = this.driver.getFlags().defaultSchema; + + for (const [schemaName, schema] of Object.entries(schemas)) { + for (const table of schema) { + if (!table.tableSchema) continue; + if (!["table", "view"].includes(table.type)) continue; + + schemaParts.push( + this.convertTableToDDLContent( + defaultSchema.toLowerCase() === schemaName.toLowerCase() + ? "" + : schemaName, + table.tableSchema + ) + ); + } + } + + return schemaParts.join("\n\n"); + } +} diff --git a/src/drivers/agent/list.tsx b/src/drivers/agent/list.tsx new file mode 100644 index 00000000..c243f79d --- /dev/null +++ b/src/drivers/agent/list.tsx @@ -0,0 +1,104 @@ +import { CloudflareIcon } from "@/components/icons/outerbase-icon"; +import { ReactElement } from "react"; +import { BaseDriver } from "../base-driver"; +import { AgentBaseDriver, AgentPromptOption } from "./base"; +import { ChatGPTDriver } from "./chatgpt"; +import CloudflareAgentDriver from "./cloudflare"; + +interface AgentDriverListItem { + name: string; + free?: boolean; + available: boolean; +} + +interface AgentDriverListGroup { + name: string; + title: ReactElement | string; + agents: AgentDriverListItem[]; +} + +const DEFAULT_FREE_TIER_MODEL = "llama-3.3-70b"; + +export default class AgentDriverList { + protected dict: Record = {}; + protected defaultModelName: string | undefined; + + constructor(databaseDriver: BaseDriver, token?: string) { + this.dict = { + "llama-3.3-70b": new CloudflareAgentDriver( + databaseDriver, + "@cf/meta/llama-3.3-70b-instruct-fp8-fast" + ), + + "sqlcoder-7b-2": new CloudflareAgentDriver( + databaseDriver, + "@cf/defog/sqlcoder-7b-2" + ), + + "gpt-4o mini": token + ? new ChatGPTDriver(databaseDriver, token) + : undefined, + }; + + this.defaultModelName = + localStorage.getItem("default-agent-model") ?? DEFAULT_FREE_TIER_MODEL; + } + + setDefaultModelName(name: string) { + this.defaultModelName = name; + localStorage.setItem("default-agent-model", name); + } + + getDefaultModelName(): string { + return this.defaultModelName || "gemma-2b-it-lora"; + } + + list(): AgentDriverListGroup[] { + return [ + { + name: "cloudflare", + title: ( +
+ Powered by{" "} + + Cloudflare Workers AI +
+ ), + agents: [ + { + name: "llama-3.3-70b", + free: true, + available: !!this.dict["llama-3.3-70b"], + }, + { + name: "sqlcoder-7b-2", + free: true, + available: !!this.dict["sqlcoder-7b-2"], + }, + ], + }, + { + name: "other", + title: "Bring your own model", + agents: [ + { name: "gpt-4o mini", available: !!this.dict["gpt-4o mini"] }, + ], + }, + ]; + } + + async run( + modelName: string, + message: string, + sessionId: string | undefined, + options: AgentPromptOption + ): Promise { + const driver = this.dict[modelName]; + + if (!driver) { + throw new Error(`Selected model ${modelName} is not available`); + } + + return await driver.run(message, sessionId, options); + } +} diff --git a/src/drivers/collaboration-driver-base.ts b/src/drivers/collaboration-driver-base.ts deleted file mode 100644 index fe8426ab..00000000 --- a/src/drivers/collaboration-driver-base.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { ApiRole, ApiUserRole } from "@/lib/api/api-database-response"; - -export abstract class CollaborationBaseDriver { - abstract getRoles(): Promise; - abstract getUsers(): Promise; - abstract assignUser(userId: string, roleId: string): Promise; - abstract deleteUser(userId: string): Promise; -} diff --git a/src/drivers/collaboration-driver.ts b/src/drivers/collaboration-driver.ts deleted file mode 100644 index ed94d5d4..00000000 --- a/src/drivers/collaboration-driver.ts +++ /dev/null @@ -1,50 +0,0 @@ -import { RequestDatabaseBody } from "@/lib/api/api-database-request"; -import { - ApiRole, - ApiRolesResponse, - ApiUserListResponse, - ApiUserRole, -} from "@/lib/api/api-database-response"; -import { CollaborationBaseDriver } from "./collaboration-driver-base"; - -export default class Collaborator implements CollaborationBaseDriver { - protected id: string = ""; - protected authToken = ""; - - constructor(id: string, authToken: string) { - this.id = id; - this.authToken = authToken; - } - - protected async request(body: RequestDatabaseBody) { - const r = await fetch(`/api/database/${this.id}/ops`, { - method: "POST", - headers: { - Authorization: "Bearer " + this.authToken, - "Content-Type": "application/json", - }, - body: JSON.stringify(body), - }); - - const json = await r.json(); - if (json?.error) throw new Error(json.error); - - return json as T; - } - - async getRoles(): Promise { - return (await this.request({ type: "roles" })).roles; - } - - async getUsers(): Promise { - return (await this.request({ type: "users" })).users; - } - - async assignUser(userId: string, roleId: string) { - await this.request({ type: "assign-user", roleId, userId }); - } - - async deleteUser(userId: string) { - await this.request({ type: "delete-user", userId }); - } -} diff --git a/src/drivers/remote-driver.ts b/src/drivers/remote-driver.ts deleted file mode 100644 index e70f2294..00000000 --- a/src/drivers/remote-driver.ts +++ /dev/null @@ -1,61 +0,0 @@ -import { - ApiOpsBatchResponse, - ApiOpsQueryResponse, -} from "@/lib/api-response-types"; -import { RequestOperationBody } from "@/lib/api/api-request-types"; -import { SqliteLikeBaseDriver } from "./sqlite-base-driver"; - -export default class RemoteDriver extends SqliteLikeBaseDriver { - protected id: string = ""; - protected authToken = ""; - protected type: "temporary" | "remote" = "remote"; - - constructor(type: "temporary" | "remote", id: string, authToken: string) { - super(); - this.id = id; - this.authToken = authToken; - this.type = type; - } - - close(): void { - // do nothing - } - protected async request(body: RequestOperationBody) { - const url = - this.type === "temporary" - ? `/api/temp_ops/${this.id}` - : `/api/ops/${this.id}`; - - const r = await fetch(url, { - method: "POST", - headers: { - Authorization: "Bearer " + this.authToken, - "Content-Type": "application/json", - }, - body: JSON.stringify(body), - }); - - const json = await r.json(); - if (json?.error) throw new Error(json.error); - - return json as T; - } - - async query(stmt: string) { - const r = await this.request({ - type: "query", - statement: stmt, - }); - - return r.data; - } - - async transaction(stmt: string[]) { - const r = await this.request({ - type: "batch", - statements: stmt, - }); - - return r.data; - } -} diff --git a/src/lib/ai-agent-storage.ts b/src/lib/ai-agent-storage.ts index cc3f391a..1d662c2d 100644 --- a/src/lib/ai-agent-storage.ts +++ b/src/lib/ai-agent-storage.ts @@ -1,4 +1,4 @@ -import { ChatGPTDriver } from "@/drivers/agent/chatgpt"; +import AgentDriverList from "@/drivers/agent/list"; import { BaseDriver } from "@/drivers/base-driver"; import { useMemo } from "react"; import useSWR, { mutate } from "swr"; @@ -32,7 +32,7 @@ export function updateAgentFromLocalStorage(data: LocalAgentType) { mutate("/local-agent-setting", data); } -export function useAgentFromLocalStorage(databaseDriver?: BaseDriver | null) { +export function useAvailableAIAgents(databaseDriver?: BaseDriver | null) { const { data: agentConfig } = useSWR( "/local-agent-setting", getAgentFromLocalStorage @@ -40,8 +40,6 @@ export function useAgentFromLocalStorage(databaseDriver?: BaseDriver | null) { return useMemo(() => { if (!databaseDriver) return undefined; - if (!agentConfig) return undefined; - - return new ChatGPTDriver(databaseDriver, agentConfig.token); + return new AgentDriverList(databaseDriver, agentConfig?.token); }, [databaseDriver, agentConfig]); } diff --git a/src/lib/api/api-database-request.ts b/src/lib/api/api-database-request.ts deleted file mode 100644 index f0983957..00000000 --- a/src/lib/api/api-database-request.ts +++ /dev/null @@ -1,24 +0,0 @@ -export interface RequestDatabaseRoleList { - type: "roles"; -} - -export interface RequestDatabaseUserList { - type: "users"; -} - -export interface RequestDatabaseAssignUser { - type: "assign-user"; - roleId: string; - userId: string; -} - -export interface RequestDatabaseDeleteUser { - type: "delete-user"; - userId: string; -} - -export type RequestDatabaseBody = - | RequestDatabaseRoleList - | RequestDatabaseUserList - | RequestDatabaseAssignUser - | RequestDatabaseDeleteUser; diff --git a/src/lib/api/api-database-response.ts b/src/lib/api/api-database-response.ts deleted file mode 100644 index 52d488cb..00000000 --- a/src/lib/api/api-database-response.ts +++ /dev/null @@ -1,25 +0,0 @@ -export interface ApiRole { - id: string; - name: string; -} - -export interface ApiUser { - id: string; - name: string; -} - -export interface ApiUserRole { - id: string; - name: string; - role: ApiRole; - createdAt: number; - assignedBy: ApiUser; -} - -export interface ApiRolesResponse { - roles: ApiRole[]; -} - -export interface ApiUserListResponse { - users: ApiUserRole[]; -} diff --git a/src/lib/api/api-request-types.ts b/src/lib/api/api-request-types.ts deleted file mode 100644 index 3f6d6c01..00000000 --- a/src/lib/api/api-request-types.ts +++ /dev/null @@ -1,13 +0,0 @@ -export interface RequestOperationBatch { - type: "batch"; - statements: string[]; -} - -export interface RequestOperationQuery { - type: "query"; - statement: string; -} - -export type RequestOperationBody = - | RequestOperationBatch - | RequestOperationQuery; diff --git a/src/lib/api/fetch-databases.ts b/src/lib/api/fetch-databases.ts deleted file mode 100644 index 0adb292d..00000000 --- a/src/lib/api/fetch-databases.ts +++ /dev/null @@ -1,54 +0,0 @@ -import { - SavedConnectionItem, - SavedConnectionItemConfig, - SupportedDriver, -} from "@/app/(theme)/connect/saved-connection-storage"; -import { ApiDatabasesResponse } from "../api-response-types"; - -export async function getDatabases(): Promise { - const result = await fetch("/api/databases"); - return await result.json(); -} - -export async function createDatabase( - data: SavedConnectionItemConfig & { driver: SupportedDriver } -): Promise<{ data: SavedConnectionItem }> { - const r = await fetch(`/api/database`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(data), - }); - - return await r.json(); -} - -export async function updateDatabase( - id: string, - data: SavedConnectionItemConfig -) { - const r = await fetch(`/api/database/${id}`, { - method: "PUT", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(data), - }); - - return await r.json(); -} - -export async function getDatabase( - id: string -): Promise { - const result = await fetch(`/api/database/${id}`); - return await result.json(); -} - -export async function deleteDatabase( - id: string -): Promise { - const result = await fetch(`/api/database/${id}`, { method: "DELETE" }); - return await result.json(); -}