Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ function nanoid(length: Int?, format: String?): String {
function ulid(format: String?): String {
} @@@expressionContext([DefaultValue])

/**
* Generates a custom identifier. The ORM client must be initialized with an
* implementation of this function.
*/
function customId(length: Int?): String {
} @@@expressionContext([DefaultValue])

/**
* Creates a sequence of integers in the underlying database and assign the incremented
* values to the ID values of the created records based on the sequence.
Expand Down
14 changes: 14 additions & 0 deletions packages/language/src/validators/function-invocation-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,20 @@ export default class FunctionInvocationValidator implements AstValidator<Express
}
}

@func('customId')
private _checkCustomId(expr: InvocationExpr, accept: ValidationAcceptor) {
// first argument must be positive if provided
const lengthArg = expr.args[0]?.value;
if (lengthArg) {
const length = getLiteral<number>(lengthArg);
if (length !== undefined && length <= 0) {
accept('error', 'first argument must be a positive number', {
node: expr.args[0]!,
});
}
}
}

@func('auth')
private _checkAuth(expr: InvocationExpr, accept: ValidationAcceptor) {
if (!expr.$resolvedType) {
Expand Down
32 changes: 32 additions & 0 deletions packages/language/test/function-invocation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,4 +414,36 @@ describe('Function Invocation Tests', () => {
);
});
});

describe('customId() length validation', () => {
it('should reject non-positive lengths', async () => {
await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(customId(0))
}
`,
'first argument must be a positive number',
);

await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(customId(-1))
}
`,
'first argument must be a positive number',
);
});
});
});
20 changes: 18 additions & 2 deletions packages/orm/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import { NUMERIC_FIELD_TYPES } from '../../constants';
import { TransactionIsolationLevel, type ClientContract, type CRUD } from '../../contract';
import type { FindArgs, SelectIncludeOmit, WhereInput } from '../../crud-types';
import {
createConfigError,
createDBQueryError,
createInternalError,
createInvalidInputError,
Expand Down Expand Up @@ -1044,7 +1045,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}
if (!(field in data)) {
if (typeof fieldDef?.default === 'object' && 'kind' in fieldDef.default) {
const generated = this.evalGenerator(fieldDef.default);
const generated = this.evalGenerator(fieldDef.default, modelDef.name);
if (generated !== undefined) {
values[field] = this.dialect.transformInput(
generated,
Expand Down Expand Up @@ -1072,7 +1073,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return values;
}

private evalGenerator(defaultValue: Expression) {
private evalGenerator(defaultValue: Expression, model: string) {
if (ExpressionUtils.isCall(defaultValue)) {
const firstArgVal =
defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0])
Expand All @@ -1095,6 +1096,21 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return this.formatGeneratedValue(generated, defaultValue.args?.[1]);
})
.with('ulid', () => this.formatGeneratedValue(ulid(), defaultValue.args?.[0]))
.with('customId', () => {
if (!this.client.$options.customId) {
throw createConfigError('"customId" implementation not provided');
}
const length = typeof firstArgVal === 'number' ? firstArgVal : undefined;
const generated = this.client.$options.customId({
client: this.client,
model: model as GetModels<Schema>,
length,
});
if (!generated || typeof generated !== 'string') {
throw createConfigError('"customId" must return a non-empty string');
}
return generated;
})
.otherwise(() => undefined);
} else if (
ExpressionUtils.isMember(defaultValue) &&
Expand Down
25 changes: 25 additions & 0 deletions packages/orm/src/client/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ export type ZModelFunction<Schema extends SchemaDef> = (
context: ZModelFunctionContext<Schema>,
) => Expression<unknown>;

export type CustomIdFunctionContext<Schema extends SchemaDef> = {
/**
* ZenStack client instance.
*/
client: ClientContract<Schema>;

/**
* The model for which the ID should be generated.
*/
model: GetModels<Schema>;

/**
* The length of the ID as requested by the schema.
*/
length?: number;
};

export type CustomIdFunction<Schema extends SchemaDef> = (ctx: CustomIdFunctionContext<Schema>) => string;

/**
* ZenStack client options.
*/
Expand Down Expand Up @@ -82,6 +101,12 @@ export type ClientOptions<Schema extends SchemaDef> = {
*/
validateInput?: boolean;

/**
* Implementation of a custom ID generation function, which is called from ZModel as
* `@default(customId())`.
*/
customId?: CustomIdFunction<Schema>;

/**
* Options for omitting fields in ORM query results.
*/
Expand Down
180 changes: 180 additions & 0 deletions tests/e2e/orm/client-api/custom-id.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import { createTestClient } from '@zenstackhq/testtools';
import { describe, expect, it } from 'vitest';

const schema = `
model User {
uid String @id @default(customId())
posts Post[]
}

model Post {
pid String @id @default(customId())
userId String?
user User? @relation(fields: [userId], references: [uid])
comments Comment[]
}

model Comment {
cid String @id @default(customId())
postId String?
post Post? @relation(fields: [postId], references: [pid])
}
`;

describe('customId', () => {
it('works with no arguments', async () => {
let client = await createTestClient(schema, {
customId: ({ model, length, client }) => `${model}.${length ?? 16}.${client.$auth!['uid']}`,
});

client = client.$setAuth({
uid: '1',
});

await expect(client.user.create({ data: {} })).resolves.toMatchObject({
uid: 'User.16.1',
});

await expect(client.post.create({ data: {} })).resolves.toMatchObject({
pid: 'Post.16.1',
});

await expect(client.comment.create({ data: {} })).resolves.toMatchObject({
cid: 'Comment.16.1',
});
});

it('works with arguments', async () => {
const schema = `
model User {
uid String @id @default(customId(8))
posts Post[]
}

model Post {
pid String @id @default(customId(8))
userId String?
user User? @relation(fields: [userId], references: [uid])
comments Comment[]
}

model Comment {
cid String @id @default(customId(8))
postId String?
post Post? @relation(fields: [postId], references: [pid])
}
`;

let client = await createTestClient(schema, {
customId: ({ model, length, client }) => `${model}.${length}.${client.$auth!['uid']}`,
});

client = client.$setAuth({
uid: '1',
});

await expect(client.user.create({ data: {} })).resolves.toMatchObject({
uid: 'User.8.1',
});

await expect(client.post.create({ data: {} })).resolves.toMatchObject({
pid: 'Post.8.1',
});

await expect(client.comment.create({ data: {} })).resolves.toMatchObject({
cid: 'Comment.8.1',
});
});

it('works with nested', async () => {
let client = await createTestClient(schema, {
customId: ({ model, length, client }) => `${model}.${length ?? 16}.${client.$auth!['uid']}`,
});

client = client.$setAuth({
uid: '1',
});

await expect(client.user.create({
data: {
posts: {
create: {},
},
},
})).resolves.toMatchObject({
uid: 'User.16.1',
});

await expect(client.post.findUnique({
where: {
pid: 'Post.16.1',
}
})).resolves.toBeTruthy();
});

it('works with deeply nested', async () => {
let client = await createTestClient(schema, {
customId: ({ model, length, client }) => `${model}.${length ?? 16}.${client.$auth!['uid']}`,
});

client = client.$setAuth({
uid: '1',
});

await expect(client.user.create({
data: {
posts: {
create: {
comments: {
create: {},
},
},
},
},
})).resolves.toMatchObject({
uid: 'User.16.1',
});

await expect(client.post.findUnique({
where: {
pid: 'Post.16.1',
}
})).resolves.toBeTruthy();

await expect(client.comment.findUnique({
where: {
cid: 'Comment.16.1',
}
})).resolves.toBeTruthy();
});

it('rejects without an implementation', async () => {
const client = await createTestClient(schema);
await expect(client.user.create({ data: {} })).rejects.toThrowError('implementation not provided');
});

it('rejects without a valid implementation (undefined)', async () => {
// @ts-expect-error
const client = await createTestClient(schema, {
customId: () => undefined,
});
// @ts-expect-error
await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string');
});

it('rejects without a valid implementation (empty string)', async () => {
const client = await createTestClient(schema, {
customId: () => '',
});
await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string');
});

it('rejects without a valid implementation (non-string)', async () => {
// @ts-expect-error
const client = await createTestClient(schema, {
customId: () => 1,
});
// @ts-expect-error
await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string');
});
});
Loading