From 886fc35d3f94353066893c2b7a13ee771b08a957 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Thu, 19 Feb 2026 00:12:34 -0800 Subject: [PATCH] chore(rivetkit): make execute generic --- .../driver-test-suite/actor-db-drizzle.ts | 36 +++++++++---------- .../driver-test-suite/actor-db-raw.ts | 36 +++++++++---------- .../fixtures/driver-test-suite/workflow.ts | 16 ++++----- .../packages/rivetkit/src/db/config.ts | 11 ++++-- .../packages/rivetkit/src/db/drizzle/mod.ts | 33 +++++++++++++++-- .../packages/rivetkit/src/db/mod.ts | 20 ++++++++--- .../rivetkit/tests/actor-types.test.ts | 18 ++++++++++ 7 files changed, 115 insertions(+), 55 deletions(-) diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-drizzle.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-drizzle.ts index 1cd7f297e8..bf74dffeef 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-drizzle.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-drizzle.ts @@ -16,38 +16,38 @@ export const dbActorDrizzle = actor({ await c.db.execute( `INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()})`, ); - const results = (await c.db.execute( + const results = await c.db.execute<{ id: number }>( `SELECT last_insert_rowid() as id`, - )) as Array<{ id: number }>; + ); return { id: results[0].id }; }, getValues: async (c) => { - const results = (await c.db.execute( - `SELECT * FROM test_data ORDER BY id`, - )) as Array<{ + const results = await c.db.execute<{ id: number; value: string; payload: string; created_at: number; - }>; + }>( + `SELECT * FROM test_data ORDER BY id`, + ); return results; }, getValue: async (c, id: number) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ value: string }>( `SELECT value FROM test_data WHERE id = ${id}`, - )) as Array<{ value: string }>; + ); return results[0]?.value ?? null; }, getCount: async (c) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ count: number }>( `SELECT COUNT(*) as count FROM test_data`, - )) as Array<{ count: number }>; + ); return results[0].count; }, rawSelectCount: async (c) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ count: number }>( `SELECT COUNT(*) as count FROM test_data`, - )) as Array<{ count: number }>; + ); return results[0]?.count ?? 0; }, insertMany: async (c, count: number) => { @@ -88,15 +88,15 @@ export const dbActorDrizzle = actor({ await c.db.execute( `INSERT INTO test_data (value, payload, created_at) VALUES ('payload', '${payload}', ${Date.now()})`, ); - const results = (await c.db.execute( + const results = await c.db.execute<{ id: number }>( `SELECT last_insert_rowid() as id`, - )) as Array<{ id: number }>; + ); return { id: results[0].id, size }; }, getPayloadSize: async (c, id: number) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ size: number }>( `SELECT length(payload) as size FROM test_data WHERE id = ${id}`, - )) as Array<{ size: number }>; + ); return results[0]?.size ?? 0; }, repeatUpdate: async (c, id: number, count: number) => { @@ -119,9 +119,9 @@ export const dbActorDrizzle = actor({ await c.db.execute( `BEGIN; INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()}); UPDATE test_data SET value = '${value}-updated' WHERE id = last_insert_rowid(); COMMIT;`, ); - const results = (await c.db.execute( + const results = await c.db.execute<{ value: string }>( `SELECT value FROM test_data ORDER BY id DESC LIMIT 1`, - )) as Array<{ value: string }>; + ); return results[0]?.value ?? null; }, triggerSleep: (c) => { diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-raw.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-raw.ts index 89ed0d02c9..1eac26184e 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-raw.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-raw.ts @@ -22,38 +22,38 @@ export const dbActorRaw = actor({ await c.db.execute( `INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()})`, ); - const results = (await c.db.execute( + const results = await c.db.execute<{ id: number }>( `SELECT last_insert_rowid() as id`, - )) as Array<{ id: number }>; + ); return { id: results[0].id }; }, getValues: async (c) => { - const results = (await c.db.execute( - `SELECT * FROM test_data ORDER BY id`, - )) as Array<{ + const results = await c.db.execute<{ id: number; value: string; payload: string; created_at: number; - }>; + }>( + `SELECT * FROM test_data ORDER BY id`, + ); return results; }, getValue: async (c, id: number) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ value: string }>( `SELECT value FROM test_data WHERE id = ${id}`, - )) as Array<{ value: string }>; + ); return results[0]?.value ?? null; }, getCount: async (c) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ count: number }>( `SELECT COUNT(*) as count FROM test_data`, - )) as Array<{ count: number }>; + ); return results[0].count; }, rawSelectCount: async (c) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ count: number }>( `SELECT COUNT(*) as count FROM test_data`, - )) as Array<{ count: number }>; + ); return results[0].count; }, insertMany: async (c, count: number) => { @@ -94,15 +94,15 @@ export const dbActorRaw = actor({ await c.db.execute( `INSERT INTO test_data (value, payload, created_at) VALUES ('payload', '${payload}', ${Date.now()})`, ); - const results = (await c.db.execute( + const results = await c.db.execute<{ id: number }>( `SELECT last_insert_rowid() as id`, - )) as Array<{ id: number }>; + ); return { id: results[0].id, size }; }, getPayloadSize: async (c, id: number) => { - const results = (await c.db.execute( + const results = await c.db.execute<{ size: number }>( `SELECT length(payload) as size FROM test_data WHERE id = ${id}`, - )) as Array<{ size: number }>; + ); return results[0]?.size ?? 0; }, repeatUpdate: async (c, id: number, count: number) => { @@ -125,9 +125,9 @@ export const dbActorRaw = actor({ await c.db.execute( `BEGIN; INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()}); UPDATE test_data SET value = '${value}-updated' WHERE id = last_insert_rowid(); COMMIT;`, ); - const results = (await c.db.execute( + const results = await c.db.execute<{ value: string }>( `SELECT value FROM test_data ORDER BY id DESC LIMIT 1`, - )) as Array<{ value: string }>; + ); return results[0]?.value ?? null; }, triggerSleep: (c) => { diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts index d74bf01a06..220652333d 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts @@ -132,18 +132,18 @@ export const workflowAccessActor = actor({ } await loopCtx.step("access-step", async () => { - await actorLoopCtx.db.execute( + await loopCtx.db.execute( `INSERT INTO workflow_access_log (created_at) VALUES (${Date.now()})`, ); - const counts = (await actorLoopCtx.db.execute( + const counts = await loopCtx.db.execute<{ count: number }>( `SELECT COUNT(*) as count FROM workflow_access_log`, - )) as Array<{ count: number }>; - const client = actorLoopCtx.client(); + ); + const client = loopCtx.client(); - actorLoopCtx.state.outsideDbError = outsideDbError; - actorLoopCtx.state.outsideClientError = outsideClientError; - actorLoopCtx.state.insideDbCount = counts[0]?.count ?? 0; - actorLoopCtx.state.insideClientAvailable = + loopCtx.state.outsideDbError = outsideDbError; + loopCtx.state.outsideClientError = outsideClientError; + loopCtx.state.insideDbCount = counts[0]?.count ?? 0; + loopCtx.state.insideClientAvailable = typeof client.workflowQueueActor.getForId === "function"; }); diff --git a/rivetkit-typescript/packages/rivetkit/src/db/config.ts b/rivetkit-typescript/packages/rivetkit/src/db/config.ts index e94c9b4dbe..50d1e7682e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/db/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/db/config.ts @@ -66,7 +66,10 @@ export type DatabaseProvider = { * Raw database client with basic exec method */ export interface RawDatabaseClient { - exec: (query: string, ...args: unknown[]) => Promise | unknown[]; + exec: = Record>( + query: string, + ...args: unknown[] + ) => Promise | TRow[]; } /** @@ -77,10 +80,12 @@ export interface DrizzleDatabaseClient { // For now, just a marker interface } -type ExecuteFunction = ( +type ExecuteFunction = < + TRow extends Record = Record, +>( query: string, ...args: unknown[] -) => Promise; +) => Promise; export type RawAccess = { /** diff --git a/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts b/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts index 2b3ec95b5c..a7875071f1 100644 --- a/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts @@ -163,9 +163,36 @@ export function db< const client = proxyDrizzle(callback, config); return Object.assign(client, { - execute: async (query: string, ...args: unknown[]) => { - const result = await callback(query, args, "all"); - return result.rows; + execute: async < + TRow extends Record = Record, + >( + query: string, + ...args: unknown[] + ): Promise => { + if (args.length > 0) { + const { rows, columns } = await waDb.query(query, args); + return rows.map((row: unknown[]) => { + const rowObj: Record = {}; + for (let i = 0; i < row.length; i++) { + rowObj[columns[i]] = row[i]; + } + return rowObj; + }) as TRow[]; + } + + const results: Record[] = []; + let columnNames: string[] | null = null; + await waDb.exec(query, (row: unknown[], columns: string[]) => { + if (!columnNames) { + columnNames = columns; + } + const rowObj: Record = {}; + for (let i = 0; i < row.length; i++) { + rowObj[columnNames[i]] = row[i]; + } + results.push(rowObj); + }); + return results as TRow[]; }, close: async () => { await waDb.close(); diff --git a/rivetkit-typescript/packages/rivetkit/src/db/mod.ts b/rivetkit-typescript/packages/rivetkit/src/db/mod.ts index 2f586b5b0c..1cd28de725 100644 --- a/rivetkit-typescript/packages/rivetkit/src/db/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/db/mod.ts @@ -46,8 +46,13 @@ export function db({ if (override) { // Use the override return { - execute: async (query, ...args) => { - return override.exec(query, ...args); + execute: async < + TRow extends Record = Record, + >( + query: string, + ...args: unknown[] + ): Promise => { + return await override.exec(query, ...args); }, close: async () => { // Override clients don't need cleanup @@ -64,7 +69,12 @@ export function db({ const db = await ctx.sqliteVfs.open(ctx.actorId, kvStore); return { - execute: async (query, ...args) => { + execute: async < + TRow extends Record = Record, + >( + query: string, + ...args: unknown[] + ): Promise => { if (args.length > 0) { // Use parameterized query when args are provided const { rows, columns } = await db.query(query, args); @@ -74,7 +84,7 @@ export function db({ rowObj[columns[i]] = row[i]; } return rowObj; - }); + }) as TRow[]; } // Use exec for non-parameterized queries @@ -90,7 +100,7 @@ export function db({ } results.push(rowObj); }); - return results; + return results as TRow[]; }, close: async () => { await db.close(); diff --git a/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts b/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts index 780670681f..2105dcdbc6 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts @@ -3,6 +3,7 @@ import { actor, event, queue } from "@/actor/mod"; import type { ActorContext, ActorContextOf } from "@/actor/contexts"; import type { ActorDefinition } from "@/actor/definition"; import type { DatabaseProviderContext } from "@/db/config"; +import { db } from "@/db/mod"; import { workflow } from "@/workflow/mod"; describe("ActorDefinition", () => { @@ -239,4 +240,21 @@ describe("ActorDefinition", () => { }); }); }); + + describe("database type inference", () => { + it("supports typed rows for c.db.execute", () => { + actor({ + state: {}, + db: db(), + actions: { + readFoo: async (c) => { + const rows = await c.db.execute<{ foo: string }>( + "SELECT foo FROM bar", + ); + expectTypeOf(rows).toEqualTypeOf>(); + }, + }, + }); + }); + }); });