Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,38 @@ export const dbActorDrizzle = actor({
await c.db.execute(
`INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()})`,
);
const results = (await c.db.execute(
const results = await c.db.execute<{ id: number }>(
`SELECT last_insert_rowid() as id`,
)) as Array<{ id: number }>;
);
return { id: results[0].id };
},
getValues: async (c) => {
const results = (await c.db.execute(
`SELECT * FROM test_data ORDER BY id`,
)) as Array<{
const results = await c.db.execute<{
id: number;
value: string;
payload: string;
created_at: number;
}>;
}>(
`SELECT * FROM test_data ORDER BY id`,
);
return results;
},
getValue: async (c, id: number) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ value: string }>(
`SELECT value FROM test_data WHERE id = ${id}`,
)) as Array<{ value: string }>;
);
return results[0]?.value ?? null;
},
getCount: async (c) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ count: number }>(
`SELECT COUNT(*) as count FROM test_data`,
)) as Array<{ count: number }>;
);
return results[0].count;
},
rawSelectCount: async (c) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ count: number }>(
`SELECT COUNT(*) as count FROM test_data`,
)) as Array<{ count: number }>;
);
return results[0]?.count ?? 0;
},
insertMany: async (c, count: number) => {
Expand Down Expand Up @@ -88,15 +88,15 @@ export const dbActorDrizzle = actor({
await c.db.execute(
`INSERT INTO test_data (value, payload, created_at) VALUES ('payload', '${payload}', ${Date.now()})`,
);
const results = (await c.db.execute(
const results = await c.db.execute<{ id: number }>(
`SELECT last_insert_rowid() as id`,
)) as Array<{ id: number }>;
);
return { id: results[0].id, size };
},
getPayloadSize: async (c, id: number) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ size: number }>(
`SELECT length(payload) as size FROM test_data WHERE id = ${id}`,
)) as Array<{ size: number }>;
);
return results[0]?.size ?? 0;
},
repeatUpdate: async (c, id: number, count: number) => {
Expand All @@ -119,9 +119,9 @@ export const dbActorDrizzle = actor({
await c.db.execute(
`BEGIN; INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()}); UPDATE test_data SET value = '${value}-updated' WHERE id = last_insert_rowid(); COMMIT;`,
);
const results = (await c.db.execute(
const results = await c.db.execute<{ value: string }>(
`SELECT value FROM test_data ORDER BY id DESC LIMIT 1`,
)) as Array<{ value: string }>;
);
return results[0]?.value ?? null;
},
triggerSleep: (c) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,38 @@ export const dbActorRaw = actor({
await c.db.execute(
`INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()})`,
);
const results = (await c.db.execute(
const results = await c.db.execute<{ id: number }>(
`SELECT last_insert_rowid() as id`,
)) as Array<{ id: number }>;
);
return { id: results[0].id };
},
getValues: async (c) => {
const results = (await c.db.execute(
`SELECT * FROM test_data ORDER BY id`,
)) as Array<{
const results = await c.db.execute<{
id: number;
value: string;
payload: string;
created_at: number;
}>;
}>(
`SELECT * FROM test_data ORDER BY id`,
);
return results;
},
getValue: async (c, id: number) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ value: string }>(
`SELECT value FROM test_data WHERE id = ${id}`,
)) as Array<{ value: string }>;
);
return results[0]?.value ?? null;
},
getCount: async (c) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ count: number }>(
`SELECT COUNT(*) as count FROM test_data`,
)) as Array<{ count: number }>;
);
return results[0].count;
},
rawSelectCount: async (c) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ count: number }>(
`SELECT COUNT(*) as count FROM test_data`,
)) as Array<{ count: number }>;
);
return results[0].count;
},
insertMany: async (c, count: number) => {
Expand Down Expand Up @@ -94,15 +94,15 @@ export const dbActorRaw = actor({
await c.db.execute(
`INSERT INTO test_data (value, payload, created_at) VALUES ('payload', '${payload}', ${Date.now()})`,
);
const results = (await c.db.execute(
const results = await c.db.execute<{ id: number }>(
`SELECT last_insert_rowid() as id`,
)) as Array<{ id: number }>;
);
return { id: results[0].id, size };
},
getPayloadSize: async (c, id: number) => {
const results = (await c.db.execute(
const results = await c.db.execute<{ size: number }>(
`SELECT length(payload) as size FROM test_data WHERE id = ${id}`,
)) as Array<{ size: number }>;
);
return results[0]?.size ?? 0;
},
repeatUpdate: async (c, id: number, count: number) => {
Expand All @@ -125,9 +125,9 @@ export const dbActorRaw = actor({
await c.db.execute(
`BEGIN; INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()}); UPDATE test_data SET value = '${value}-updated' WHERE id = last_insert_rowid(); COMMIT;`,
);
const results = (await c.db.execute(
const results = await c.db.execute<{ value: string }>(
`SELECT value FROM test_data ORDER BY id DESC LIMIT 1`,
)) as Array<{ value: string }>;
);
return results[0]?.value ?? null;
},
triggerSleep: (c) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,18 @@ export const workflowAccessActor = actor({
}

await loopCtx.step("access-step", async () => {
await actorLoopCtx.db.execute(
await loopCtx.db.execute(
`INSERT INTO workflow_access_log (created_at) VALUES (${Date.now()})`,
);
const counts = (await actorLoopCtx.db.execute(
const counts = await loopCtx.db.execute<{ count: number }>(
`SELECT COUNT(*) as count FROM workflow_access_log`,
)) as Array<{ count: number }>;
const client = actorLoopCtx.client();
);
const client = loopCtx.client<typeof registry>();

actorLoopCtx.state.outsideDbError = outsideDbError;
actorLoopCtx.state.outsideClientError = outsideClientError;
actorLoopCtx.state.insideDbCount = counts[0]?.count ?? 0;
actorLoopCtx.state.insideClientAvailable =
loopCtx.state.outsideDbError = outsideDbError;
loopCtx.state.outsideClientError = outsideClientError;
loopCtx.state.insideDbCount = counts[0]?.count ?? 0;
loopCtx.state.insideClientAvailable =
typeof client.workflowQueueActor.getForId === "function";
});

Expand Down
11 changes: 8 additions & 3 deletions rivetkit-typescript/packages/rivetkit/src/db/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ export type DatabaseProvider<DB extends RawAccess> = {
* Raw database client with basic exec method
*/
export interface RawDatabaseClient {
exec: (query: string, ...args: unknown[]) => Promise<unknown[]> | unknown[];
exec: <TRow extends Record<string, unknown> = Record<string, unknown>>(
query: string,
...args: unknown[]
) => Promise<TRow[]> | TRow[];
}

/**
Expand All @@ -77,10 +80,12 @@ export interface DrizzleDatabaseClient {
// For now, just a marker interface
}

type ExecuteFunction = (
type ExecuteFunction = <
TRow extends Record<string, unknown> = Record<string, unknown>,
>(
query: string,
...args: unknown[]
) => Promise<unknown[]>;
) => Promise<TRow[]>;

export type RawAccess = {
/**
Expand Down
33 changes: 30 additions & 3 deletions rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,36 @@ export function db<
const client = proxyDrizzle<TSchema>(callback, config);

return Object.assign(client, {
execute: async (query: string, ...args: unknown[]) => {
const result = await callback(query, args, "all");
return result.rows;
execute: async <
TRow extends Record<string, unknown> = Record<string, unknown>,
>(
query: string,
...args: unknown[]
): Promise<TRow[]> => {
if (args.length > 0) {
const { rows, columns } = await waDb.query(query, args);
return rows.map((row: unknown[]) => {
const rowObj: Record<string, unknown> = {};
for (let i = 0; i < row.length; i++) {
rowObj[columns[i]] = row[i];
}
return rowObj;
}) as TRow[];
}

const results: Record<string, unknown>[] = [];
let columnNames: string[] | null = null;
await waDb.exec(query, (row: unknown[], columns: string[]) => {
if (!columnNames) {
columnNames = columns;
}
const rowObj: Record<string, unknown> = {};
for (let i = 0; i < row.length; i++) {
rowObj[columnNames[i]] = row[i];
}
results.push(rowObj);
});
return results as TRow[];
},
close: async () => {
await waDb.close();
Expand Down
20 changes: 15 additions & 5 deletions rivetkit-typescript/packages/rivetkit/src/db/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@ export function db({
if (override) {
// Use the override
return {
execute: async (query, ...args) => {
return override.exec(query, ...args);
execute: async <
TRow extends Record<string, unknown> = Record<string, unknown>,
>(
query: string,
...args: unknown[]
): Promise<TRow[]> => {
return await override.exec<TRow>(query, ...args);
},
close: async () => {
// Override clients don't need cleanup
Expand All @@ -64,7 +69,12 @@ export function db({
const db = await ctx.sqliteVfs.open(ctx.actorId, kvStore);

return {
execute: async (query, ...args) => {
execute: async <
TRow extends Record<string, unknown> = Record<string, unknown>,
>(
query: string,
...args: unknown[]
): Promise<TRow[]> => {
if (args.length > 0) {
// Use parameterized query when args are provided
const { rows, columns } = await db.query(query, args);
Expand All @@ -74,7 +84,7 @@ export function db({
rowObj[columns[i]] = row[i];
}
return rowObj;
});
}) as TRow[];
}

// Use exec for non-parameterized queries
Expand All @@ -90,7 +100,7 @@ export function db({
}
results.push(rowObj);
});
return results;
return results as TRow[];
},
close: async () => {
await db.close();
Expand Down
18 changes: 18 additions & 0 deletions rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { actor, event, queue } from "@/actor/mod";
import type { ActorContext, ActorContextOf } from "@/actor/contexts";
import type { ActorDefinition } from "@/actor/definition";
import type { DatabaseProviderContext } from "@/db/config";
import { db } from "@/db/mod";
import { workflow } from "@/workflow/mod";

describe("ActorDefinition", () => {
Expand Down Expand Up @@ -239,4 +240,21 @@ describe("ActorDefinition", () => {
});
});
});

describe("database type inference", () => {
it("supports typed rows for c.db.execute", () => {
actor({
state: {},
db: db(),
actions: {
readFoo: async (c) => {
const rows = await c.db.execute<{ foo: string }>(
"SELECT foo FROM bar",
);
expectTypeOf(rows).toEqualTypeOf<Array<{ foo: string }>>();
},
},
});
});
});
});
Loading