Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/language/src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ declare module './ast' {
$resolvedParam?: AttributeParam;
}

interface BinaryExpr {
/**
* Optional iterator binding for collection predicates
*/
binding?: string;
}

export interface DataModel {
/**
* All fields including those marked with `@ignore`
Expand Down
8 changes: 6 additions & 2 deletions packages/language/src/generated/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ export function isMemberAccessTarget(item: unknown): item is MemberAccessTarget
return reflection.isInstance(item, MemberAccessTarget);
}

export type ReferenceTarget = DataField | EnumField | FunctionParam;
export type ReferenceTarget = BinaryExpr | DataField | EnumField | FunctionParam;

export const ReferenceTarget = 'ReferenceTarget';

Expand Down Expand Up @@ -256,6 +256,7 @@ export function isAttributeParamType(item: unknown): item is AttributeParamType
export interface BinaryExpr extends langium.AstNode {
readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | FieldInitializer | FunctionDecl | MemberAccessExpr | ReferenceArg | UnaryExpr;
readonly $type: 'BinaryExpr';
binding?: RegularID;
left: Expression;
operator: '!' | '!=' | '&&' | '<' | '<=' | '==' | '>' | '>=' | '?' | '^' | 'in' | '||';
right: Expression;
Expand Down Expand Up @@ -826,7 +827,6 @@ export class ZModelAstReflection extends langium.AbstractAstReflection {
protected override computeIsSubtype(subtype: string, supertype: string): boolean {
switch (subtype) {
case ArrayExpr:
case BinaryExpr:
case MemberAccessExpr:
case NullExpr:
case ObjectExpr:
Expand All @@ -843,6 +843,9 @@ export class ZModelAstReflection extends langium.AbstractAstReflection {
case Procedure: {
return this.isSubtype(AbstractDeclaration, supertype);
}
case BinaryExpr: {
return this.isSubtype(Expression, supertype) || this.isSubtype(ReferenceTarget, supertype);
}
case BooleanLiteral:
case NumberLiteral:
case StringLiteral: {
Expand Down Expand Up @@ -973,6 +976,7 @@ export class ZModelAstReflection extends langium.AbstractAstReflection {
return {
name: BinaryExpr,
properties: [
{ name: 'binding' },
{ name: 'left' },
{ name: 'operator' },
{ name: 'right' }
Expand Down
28 changes: 28 additions & 0 deletions packages/language/src/generated/grammar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,28 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"$type": "Keyword",
"value": "["
},
{
"$type": "Group",
"elements": [
{
"$type": "Assignment",
"feature": "binding",
"operator": "=",
"terminal": {
"$type": "RuleCall",
"rule": {
"$ref": "#/rules@51"
},
"arguments": []
}
},
{
"$type": "Keyword",
"value": ","
}
],
"cardinality": "?"
},
{
"$type": "Assignment",
"feature": "right",
Expand Down Expand Up @@ -3996,6 +4018,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"typeRef": {
"$ref": "#/rules@45"
}
},
{
"$type": "SimpleType",
"typeRef": {
"$ref": "#/rules@29/definition/elements@1/elements@0/inferredType"
}
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import {
DataFieldAttribute,
DataModelAttribute,
InternalAttribute,
ReferenceExpr,
isArrayExpr,
isAttribute,
isConfigArrayExpr,
Expand Down Expand Up @@ -491,9 +490,16 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) {
return true;
}

const fieldTypes = (targetField.args[0].value as ArrayExpr).items.map(
(item) => (item as ReferenceExpr).target.ref?.name,
);
const fieldTypes = (targetField.args[0].value as ArrayExpr).items
.map((item) => {
if (!isReferenceExpr(item)) {
return undefined;
}

const ref = item.target.ref;
return ref && 'name' in ref && typeof ref.name === 'string' ? ref.name : undefined;
})
.filter((name): name is string => !!name);

let allowed = false;
for (const allowedType of fieldTypes) {
Expand Down
8 changes: 5 additions & 3 deletions packages/language/src/zmodel-code-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,15 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${

const { left: isLeftParenthesis, right: isRightParenthesis } = this.isParenthesesNeededForBinaryExpr(ast);

const collectionPredicate = isCollectionPredicate
? `[${ast.binding ? `${ast.binding}, ${rightExpr}` : rightExpr}]`
: rightExpr;

return `${isLeftParenthesis ? '(' : ''}${this.generate(ast.left)}${
isLeftParenthesis ? ')' : ''
}${isCollectionPredicate ? '' : this.binaryExprSpace}${operator}${
isCollectionPredicate ? '' : this.binaryExprSpace
}${isRightParenthesis ? '(' : ''}${
isCollectionPredicate ? `[${rightExpr}]` : rightExpr
}${isRightParenthesis ? ')' : ''}`;
}${isRightParenthesis ? '(' : ''}${collectionPredicate}${isRightParenthesis ? ')' : ''}`;
}

@gen(ReferenceExpr)
Expand Down
37 changes: 29 additions & 8 deletions packages/language/src/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
DataModel,
Enum,
EnumField,
isBinaryExpr,
type ExpressionType,
FunctionDecl,
FunctionParam,
Expand Down Expand Up @@ -121,7 +122,13 @@ export class ZModelLinker extends DefaultLinker {
const target = provider(reference.$refText);
if (target) {
reference._ref = target;
reference._nodeDescription = this.descriptions.createDescription(target, target.name, document);
let targetName = reference.$refText;
if ('name' in target && typeof target.name === 'string') {
targetName = target.name;
} else if ('binding' in target && typeof (target as { binding?: unknown }).binding === 'string') {
targetName = (target as { binding: string }).binding;
}
reference._nodeDescription = this.descriptions.createDescription(target, targetName, document);

// Add the reference to the document's array of references
document.references.push(reference);
Expand Down Expand Up @@ -249,13 +256,24 @@ export class ZModelLinker extends DefaultLinker {

private resolveReference(node: ReferenceExpr, document: LangiumDocument<AstNode>, extraScopes: ScopeProvider[]) {
this.resolveDefault(node, document, extraScopes);

if (node.target.ref) {
// resolve type
if (node.target.ref.$type === EnumField) {
this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container);
} else {
this.resolveToDeclaredType(node, (node.target.ref as DataField | FunctionParam).type);
const target = node.target.ref;

if (target) {
if (isBinaryExpr(target) && ['?', '!', '^'].includes(target.operator)) {
const collectionType = target.left.$resolvedType;
if (collectionType?.decl) {
node.$resolvedType = {
decl: collectionType.decl,
array: false,
nullable: collectionType.nullable,
};
}
} else if (target.$type === EnumField) {
this.resolveToBuiltinTypeOrDecl(node, target.$container);
} else if (isDataField(target)) {
this.resolveToDeclaredType(node, target.type);
} else if (target.$type === FunctionParam && (target as FunctionParam).type) {
this.resolveToDeclaredType(node, (target as FunctionParam).type);
}
}
}
Expand Down Expand Up @@ -506,6 +524,9 @@ export class ZModelLinker extends DefaultLinker {
//#region Utils

private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataFieldType) {
if (!type) {
return;
}
let nullable = false;
if (isDataFieldType(type)) {
nullable = type.optional;
Expand Down
21 changes: 21 additions & 0 deletions packages/language/src/zmodel-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
StreamScope,
UriUtils,
interruptAndCheck,
stream,
type AstNode,
type AstNodeDescription,
type LangiumCoreServices,
Expand All @@ -18,7 +19,9 @@ import {
import { match } from 'ts-pattern';
import {
BinaryExpr,
Expression,
MemberAccessExpr,
isBinaryExpr,
isDataField,
isDataModel,
isEnumField,
Expand Down Expand Up @@ -145,6 +148,9 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
.when(isReferenceExpr, (operand) => {
// operand is a reference, it can only be a model/type-def field
const ref = operand.target.ref;
if (isBinaryExpr(ref) && isCollectionPredicate(ref)) {
return this.createScopeForCollectionElement(ref.left, globalScope, allowTypeDefScope);
}
if (isDataField(ref)) {
return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope);
}
Expand Down Expand Up @@ -188,6 +194,21 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
// // typedef's fields are only added to the scope if the access starts with `auth().`
const allowTypeDefScope = isAuthOrAuthMemberAccess(collection);

const collectionScope = this.createScopeForCollectionElement(collection, globalScope, allowTypeDefScope);

if (collectionPredicate.binding) {
const description = this.descriptions.createDescription(
collectionPredicate,
collectionPredicate.binding,
collectionPredicate.$document!,
);
return new StreamScope(stream([description]), collectionScope);
}

return collectionScope;
}

private createScopeForCollectionElement(collection: Expression, globalScope: Scope, allowTypeDefScope: boolean) {
return match(collection)
.when(isReferenceExpr, (expr) => {
// collection is a reference - model or typedef field
Expand Down
4 changes: 2 additions & 2 deletions packages/language/src/zmodel.langium
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ ConfigArrayExpr:
ConfigExpr:
LiteralExpr | InvocationExpr | ConfigArrayExpr;

type ReferenceTarget = FunctionParam | DataField | EnumField;
type ReferenceTarget = FunctionParam | DataField | EnumField | BinaryExpr;

ThisExpr:
value='this';
Expand Down Expand Up @@ -113,7 +113,7 @@ CollectionPredicateExpr infers Expression:
MemberAccessExpr (
{infer BinaryExpr.left=current}
operator=('?'|'!'|'^')
'[' right=Expression ']'
'[' (binding=RegularID ',')? right=Expression ']'
)*;

InExpr infers Expression:
Expand Down
44 changes: 44 additions & 0 deletions packages/language/test/expression-validation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,48 @@ describe('Expression Validation Tests', () => {
'incompatible operand types',
);
});

it('should allow collection predicate with iterator binding', async () => {
await loadSchema(`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id Int @id
memberships Membership[]
@@allow('read', memberships?[m, m.tenantId == id])
}

model Membership {
id Int @id
tenantId Int
user User @relation(fields: [userId], references: [id])
userId Int
}
`);
});

it('should keep supporting unbound collection predicate syntax', async () => {
await loadSchema(`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id Int @id
memberships Membership[]
@@allow('read', memberships?[tenantId == id])
}

model Membership {
id Int @id
tenantId Int
user User @relation(fields: [userId], references: [id])
userId Int
}
`);
});
});
31 changes: 29 additions & 2 deletions packages/plugins/policy/src/expression-evaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
type ExpressionEvaluatorContext = {
auth?: any;
thisValue?: any;
scope?: Record<string, any>;
};

/**
Expand Down Expand Up @@ -64,6 +65,9 @@ export class ExpressionEvaluator {
}

private evaluateField(expr: FieldExpression, context: ExpressionEvaluatorContext): any {
if (context.scope && expr.field in context.scope) {
return context.scope[expr.field];
}
return context.thisValue?.[expr.field];
}

Expand Down Expand Up @@ -113,15 +117,38 @@ export class ExpressionEvaluator {
invariant(Array.isArray(left), 'expected array');

return match(op)
.with('?', () => left.some((item: any) => this.evaluate(expr.right, { ...context, thisValue: item })))
.with('!', () => left.every((item: any) => this.evaluate(expr.right, { ...context, thisValue: item })))
.with('?', () =>
left.some((item: any) =>
this.evaluate(expr.right, {
...context,
thisValue: item,
scope: expr.binding
? { ...(context.scope ?? {}), [expr.binding]: item }
: context.scope,
}),
),
)
.with('!', () =>
left.every((item: any) =>
this.evaluate(expr.right, {
...context,
thisValue: item,
scope: expr.binding
? { ...(context.scope ?? {}), [expr.binding]: item }
: context.scope,
}),
),
)
.with(
'^',
() =>
!left.some((item: any) =>
this.evaluate(expr.right, {
...context,
thisValue: item,
scope: expr.binding
? { ...(context.scope ?? {}), [expr.binding]: item }
: context.scope,
}),
),
)
Expand Down
Loading
Loading