diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..3ba8779da --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,23 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' +--- + +**Description and expected behavior** +A clear and concise description of what the bug is and what's the expected behavior. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Environment (please complete the following information):** + +- ZenStack version: [e.g., 3.1.0] +- Database type: [e.g. Postgresql] +- Node.js/Bun version: [e.g., 20.0.0] +- Package manager: [e.g., npm, yarn, pnpm] + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..9ba516b23 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,19 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '[Feature Request] ' +labels: '' +assignees: '' +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..d9c4f07a1 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,93 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. 
+# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: 'CodeQL' + +on: + push: + branches: ['main', 'dev', 'release/*'] + pull_request: + branches: ['main', 'dev', 'release/*'] + schedule: + - cron: '25 0 * * 6' + +jobs: + analyze: + name: Analyze (${{ matrix.language }}) + # Runner size impacts CodeQL analysis time. To learn more, please see: + # - https://gh.io/recommended-hardware-resources-for-running-codeql + # - https://gh.io/supported-runners-and-hardware-resources + # - https://gh.io/using-larger-runners (GitHub.com only) + # Consider using larger runners or machines with greater resources for possible analysis time improvements. + runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} + timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} + permissions: + # required for all workflows + security-events: write + + # required to fetch internal or private CodeQL packs + packages: read + + # only required for workflows in private repositories + actions: read + contents: read + + strategy: + fail-fast: false + matrix: + include: + - language: javascript-typescript + build-mode: none + # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' + # Use `c-cpp` to analyze code written in C, C++ or both + # Use 'java-kotlin' to analyze code written in Java, Kotlin or both + # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both + # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis, + # see 
https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning. + # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how + # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + config-file: ./.github/workflows/config/codeql-config.yml + languages: ${{ matrix.language }} + build-mode: ${{ matrix.build-mode }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + + # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + # If the analyze step fails for one of the languages you are analyzing with + # "We were unable to automatically build your code", modify the matrix above + # to set the build mode to "manual" for that language. Then modify this step + # to build your code. + # â„šī¸ Command-line programs to run using the OS shell. 
+ # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + - if: matrix.build-mode == 'manual' + run: | + echo 'If you are using a "manual" build mode for one or more of the' \ + 'languages you are analyzing, replace this with the commands to build' \ + 'your code, for example:' + echo ' make bootstrap' + echo ' make release' + exit 1 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: '/language:${{matrix.language}}' diff --git a/.github/workflows/config/codeql-config.yml b/.github/workflows/config/codeql-config.yml new file mode 100644 index 000000000..c72655d1e --- /dev/null +++ b/.github/workflows/config/codeql-config.yml @@ -0,0 +1,7 @@ +paths: + - 'packages' +paths-ignore: + - '**/*.test.js' + - '**/*.test.ts' + - '**/*.test.tsx' + - '**/__tests__/**' diff --git a/.vscode/launch.json b/.vscode/launch.json index 09ccbd596..df6fbc3c6 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,7 +11,7 @@ "skipFiles": ["/**"], "type": "node", "args": ["generate"], - "cwd": "${workspaceFolder}/samples/blog" + "cwd": "${workspaceFolder}/samples/orm" }, { "name": "Debug with TSX", diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..212eb48c9 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +- Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +- The use of sexualized language or imagery, and sexual attention or + advances of any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email + address, without their explicit permission +- Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. 
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +contact@zenstack.dev. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. 
Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/README.md b/README.md index 92f24ca4d..859364b29 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ - + @@ -37,7 +37,7 @@ ZenStack is a TypeScript database toolkit for developing full-stack or backend N # What's New in V3 -ZenStack V3 is a major rewrite of [V2](https://github.com/zenstackhq/zenstack). It replaced Prisma ORM with its own ORM engine built on top of [Kysely](https://kysely.dev) while keeping a Prisma-compatible query API. This architecture change brings the level of flexibility that we couldn't imagine in previous versions. Please check [this blog post](https://zenstack.dev/blog/next-chapter-1) for why we made this bold decision. +ZenStack V3 is a major rewrite. It replaced Prisma ORM with its own ORM engine built on top of [Kysely](https://kysely.dev) while keeping a Prisma-compatible query API. This architecture change brings the level of flexibility that we couldn't imagine in previous versions. Please check [this blog post](https://zenstack.dev/blog/next-chapter-1) for why we made this bold decision. 
Even without using advanced features, ZenStack offers the following benefits as a drop-in replacement to Prisma: diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..2d74e8769 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,13 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +| ------- | ------------------ | +| >=2.0.0 | :white_check_mark: | + +## Reporting a Vulnerability + +Please send an email to contact@zenstack.dev. We'll actively watch, verify, and fix them with high priority. + +Thank you for helping us make a better project! diff --git a/packages/cli/package.json b/packages/cli/package.json index 750de5c68..ceb53e2fa 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -37,6 +37,7 @@ }, "dependencies": { "@zenstackhq/common-helpers": "workspace:*", + "@zenstackhq/schema": "workspace:*", "@zenstackhq/language": "workspace:*", "@zenstackhq/orm": "workspace:*", "@zenstackhq/sdk": "workspace:*", diff --git a/packages/cli/src/actions/action-utils.ts b/packages/cli/src/actions/action-utils.ts index d2e0ca2e9..c33a81d64 100644 --- a/packages/cli/src/actions/action-utils.ts +++ b/packages/cli/src/actions/action-utils.ts @@ -1,8 +1,9 @@ -import { loadDocument } from '@zenstackhq/language'; -import { isDataSource } from '@zenstackhq/language/ast'; +import { type ZModelServices, loadDocument } from '@zenstackhq/language'; +import { type Model, isDataSource } from '@zenstackhq/language/ast'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'node:fs'; +import { createRequire } from 'node:module'; import path from 'node:path'; import { CliError } from '../cli-error'; @@ -41,8 +42,22 @@ export function getSchemaFile(file?: string) { } } -export async function loadSchemaDocument(schemaFile: string) { - const loadResult = await loadDocument(schemaFile); +export async function loadSchemaDocument( + schemaFile: string, + opts?: { mergeImports?: boolean; 
returnServices?: false }, +): Promise; +export async function loadSchemaDocument( + schemaFile: string, + opts: { returnServices: true; mergeImports?: boolean }, +): Promise<{ model: Model; services: ZModelServices }>; +export async function loadSchemaDocument( + schemaFile: string, + opts: { returnServices?: boolean; mergeImports?: boolean } = {}, +) { + const returnServices = opts.returnServices ?? false; + const mergeImports = opts.mergeImports ?? true; + + const loadResult = await loadDocument(schemaFile, [], mergeImports); if (!loadResult.success) { loadResult.errors.forEach((err) => { console.error(colors.red(err)); @@ -52,6 +67,9 @@ export async function loadSchemaDocument(schemaFile: string) { loadResult.warnings.forEach((warn) => { console.warn(colors.yellow(warn)); }); + + if (returnServices) return { model: loadResult.model, services: loadResult.services }; + return loadResult.model; } @@ -125,10 +143,10 @@ function findUp( } const target = names.find((name) => fs.existsSync(path.join(cwd, name))); if (multiple === false && target) { - return path.join(cwd, target) as FindUpResult; + return path.resolve(cwd, target) as FindUpResult; } if (target) { - result.push(path.join(cwd, target)); + result.push(path.resolve(cwd, target)); } const up = path.resolve(cwd, '..'); if (up === cwd) { @@ -156,3 +174,45 @@ export function getOutputPath(options: { output?: string }, schemaFile: string) return path.dirname(schemaFile); } } +export async function getZenStackPackages( + searchPath: string, +): Promise> { + const pkgJsonFile = findUp(['package.json'], searchPath, false); + if (!pkgJsonFile) { + return []; + } + + let pkgJson: { + dependencies?: Record; + devDependencies?: Record; + }; + try { + pkgJson = JSON.parse(fs.readFileSync(pkgJsonFile, 'utf8')); + } catch { + return []; + } + + const packages = Array.from( + new Set( + [...Object.keys(pkgJson.dependencies ?? {}), ...Object.keys(pkgJson.devDependencies ?? 
{})].filter((p) => + p.startsWith('@zenstackhq/'), + ), + ), + ).sort(); + + const require = createRequire(pkgJsonFile); + + const result = packages.map((pkg) => { + try { + const depPkgJson = require(`${pkg}/package.json`); + if (depPkgJson.private) { + return undefined; + } + return { pkg, version: depPkgJson.version as string }; + } catch { + return { pkg, version: undefined }; + } + }); + + return result.filter((p) => !!p); +} diff --git a/packages/cli/src/actions/db.ts b/packages/cli/src/actions/db.ts index 3d0108374..f6fda3844 100644 --- a/packages/cli/src/actions/db.ts +++ b/packages/cli/src/actions/db.ts @@ -1,25 +1,54 @@ +import { formatDocument, ZModelCodeGenerator } from '@zenstackhq/language'; +import { DataModel, Enum, type Model } from '@zenstackhq/language/ast'; +import colors from 'colors'; import fs from 'node:fs'; +import path from 'node:path'; +import ora from 'ora'; import { execPrisma } from '../utils/exec-utils'; -import { generateTempPrismaSchema, getSchemaFile, handleSubProcessError, requireDataSourceUrl } from './action-utils'; +import { + generateTempPrismaSchema, + getSchemaFile, + handleSubProcessError, + loadSchemaDocument, + requireDataSourceUrl, +} from './action-utils'; +import { consolidateEnums, syncEnums, syncRelation, syncTable, type Relation } from './pull'; +import { providers as pullProviders } from './pull/provider'; +import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName, isDatabaseManagedAttribute } from './pull/utils'; +import type { DataSourceProviderType } from '@zenstackhq/schema'; +import { CliError } from '../cli-error'; -type Options = { +type PushOptions = { schema?: string; acceptDataLoss?: boolean; forceReset?: boolean; }; +export type PullOptions = { + schema?: string; + output?: string; + modelCasing: 'pascal' | 'camel' | 'snake' | 'none'; + fieldCasing: 'pascal' | 'camel' | 'snake' | 'none'; + alwaysMap: boolean; + quote: 'single' | 'double'; + indent: number; +}; + /** * CLI action for db 
related commands */ -export async function run(command: string, options: Options) { +export async function run(command: string, options: any) { switch (command) { case 'push': await runPush(options); break; + case 'pull': + await runPull(options); + break; } } -async function runPush(options: Options) { +async function runPush(options: PushOptions) { const schemaFile = getSchemaFile(options.schema); // validate datasource url exists @@ -49,3 +78,599 @@ async function runPush(options: Options) { } } } + +async function runPull(options: PullOptions) { + const spinner = ora(); + try { + const schemaFile = getSchemaFile(options.schema); + + // Determine early if `--out` is a single file output (combined schema) or a directory export. + const outPath = options.output ? path.resolve(options.output) : undefined; + const treatAsFile = + !!outPath && + ((fs.existsSync(outPath) && fs.lstatSync(outPath).isFile()) || path.extname(outPath) !== ''); + + const { model, services } = await loadSchemaDocument(schemaFile, { + returnServices: true, + mergeImports: treatAsFile, + }); + + const SUPPORTED_PROVIDERS = Object.keys(pullProviders) as DataSourceProviderType[]; + const datasource = getDatasource(model); + + if (!SUPPORTED_PROVIDERS.includes(datasource.provider)) { + throw new CliError(`Unsupported datasource provider: ${datasource.provider}`); + } + + const provider = pullProviders[datasource.provider]; + + if (!provider) { + throw new CliError(`No introspection provider found for: ${datasource.provider}`); + } + + spinner.start('Introspecting database...'); + const { enums, tables } = await provider.introspect(datasource.url, { schemas: datasource.allSchemas, modelCasing: options.modelCasing }); + spinner.succeed('Database introspected'); + + console.log(colors.blue('Syncing schema...')); + + const newModel: Model = { + $type: 'Model', + $container: undefined, + $containerProperty: undefined, + $containerIndex: undefined, + declarations: [...model.declarations.filter((d) => 
['DataSource'].includes(d.$type))], + imports: model.imports, + }; + + syncEnums({ + dbEnums: enums, + model: newModel, + services, + options, + defaultSchema: datasource.defaultSchema, + oldModel: model, + provider, + }); + + const resolvedRelations: Relation[] = []; + for (const table of tables) { + const relations = syncTable({ + table, + model: newModel, + provider, + services, + options, + defaultSchema: datasource.defaultSchema, + oldModel: model, + }); + resolvedRelations.push(...relations); + } + // sync relation fields + for (const relation of resolvedRelations) { + const similarRelations = resolvedRelations.filter((rr) => { + return ( + rr !== relation && + ((rr.schema === relation.schema && + rr.table === relation.table && + rr.references.schema === relation.references.schema && + rr.references.table === relation.references.table) || + (rr.schema === relation.references.schema && + rr.columns[0] === relation.references.columns[0] && + rr.references.schema === relation.schema && + rr.references.table === relation.table)) + ); + }).length; + const selfRelation = + relation.references.schema === relation.schema && relation.references.table === relation.table; + syncRelation({ + model: newModel, + relation, + services, + options, + selfRelation, + similarRelations: similarRelations, + }); + } + + // Consolidate per-column enums (e.g., MySQL's synthetic UserStatus/GroupStatus) + // back to shared enums from the original schema (e.g., Status) + consolidateEnums({ newModel, oldModel: model }); + + console.log(colors.blue('Schema synced')); + + const baseDir = path.dirname(path.resolve(schemaFile)); + const baseDirUrlPath = new URL(`file://${baseDir}`).pathname; + const docs = services.shared.workspace.LangiumDocuments.all + .filter(({ uri }) => uri.path.toLowerCase().startsWith(baseDirUrlPath.toLowerCase())) + .toArray(); + const docsSet = new Set(docs.map((d) => d.uri.toString())); + + console.log(colors.bold('\nApplying changes to ZModel...')); + + const 
deletedModels: string[] = []; + const deletedEnums: string[] = []; + const addedModels: string[] = []; + const addedEnums: string[] = []; + // Hierarchical change tracking: model -> field changes -> attribute changes + type ModelChanges = { + addedFields: string[]; + deletedFields: string[]; + updatedFields: string[]; + addedAttributes: string[]; + deletedAttributes: string[]; + updatedAttributes: string[]; + }; + const modelChanges = new Map(); + + const getModelChanges = (modelName: string): ModelChanges => { + if (!modelChanges.has(modelName)) { + modelChanges.set(modelName, { + addedFields: [], + deletedFields: [], + updatedFields: [], + addedAttributes: [], + deletedAttributes: [], + updatedAttributes: [], + }); + } + return modelChanges.get(modelName)!; + }; + + // Delete models + services.shared.workspace.IndexManager.allElements('DataModel', docsSet) + .filter( + (declaration) => + !newModel.declarations.find((d) => getDbName(d) === getDbName(declaration.node as any)), + ) + .forEach((decl) => { + const model = decl.node!.$container as Model; + const index = model.declarations.findIndex((d) => d === decl.node); + model.declarations.splice(index, 1); + deletedModels.push(colors.red(`- Model ${decl.name} deleted`)); + }); + + // Delete Enums + if (provider.isSupportedFeature('NativeEnum')) + services.shared.workspace.IndexManager.allElements('Enum', docsSet) + .filter( + (declaration) => + !newModel.declarations.find((d) => getDbName(d) === getDbName(declaration.node as any)), + ) + .forEach((decl) => { + const model = decl.node!.$container as Model; + const index = model.declarations.findIndex((d) => d === decl.node); + model.declarations.splice(index, 1); + deletedEnums.push(colors.red(`- Enum ${decl.name} deleted`)); + }); + // Add/update models and their fields + newModel.declarations + .filter((d) => [DataModel, Enum].includes(d.$type)) + .forEach((_declaration) => { + const newDataModel = _declaration as DataModel | Enum; + const declarations = 
services.shared.workspace.IndexManager.allElements(newDataModel.$type, docsSet).toArray(); + const originalDataModel = declarations.find((d) => getDbName(d.node as any) === getDbName(newDataModel)) + ?.node as DataModel | Enum | undefined; + if (!originalDataModel) { + + if (newDataModel.$type === 'DataModel') { + addedModels.push(colors.green(`+ Model ${newDataModel.name} added`)); + } else if (newDataModel.$type === 'Enum') { + addedEnums.push(colors.green(`+ Enum ${newDataModel.name} added`)); + } + + model.declarations.push(newDataModel); + (newDataModel as any).$container = model; + newDataModel.fields.forEach((f) => { + if (f.$type === 'DataField' && f.type.reference?.ref) { + const ref = declarations.find( + (d) => getDbName(d.node as any) === getDbName(f.type.reference!.ref as any), + )?.node; + if (ref && f.type.reference) { + // Replace the entire reference object — Langium References + // from parsed documents expose `ref` as a getter-only property. + (f.type as any).reference = { + ref, + $refText: (ref as any).name ?? 
(f.type.reference as any).$refText, + }; + } + } + }); + return; + } + + newDataModel.fields.forEach((f) => { + // Prioritized matching: exact db name > relation fields key > relation FK name > type reference + let originalFields = originalDataModel.fields.filter((d) => getDbName(d) === getDbName(f)); + + // If this is a back-reference relation field (has @relation but no `fields` arg), silently skip + const isRelationField = + f.$type === 'DataField' && !!(f as any).attributes?.some((a: any) => a?.decl?.ref?.name === '@relation'); + if (originalFields.length === 0 && isRelationField && !getRelationFieldsKey(f as any)) { + return; + } + + if (originalFields.length === 0) { + // Try matching by relation fields key (the `fields` attribute in @relation) + // This matches relation fields by their FK field references + const newFieldsKey = getRelationFieldsKey(f as any); + if (newFieldsKey) { + originalFields = originalDataModel.fields.filter( + (d) => getRelationFieldsKey(d as any) === newFieldsKey, + ); + } + } + + if (originalFields.length === 0) { + // Try matching by relation FK name (the `map` attribute in @relation) + originalFields = originalDataModel.fields.filter( + (d) => + getRelationFkName(d as any) === getRelationFkName(f as any) && + !!getRelationFkName(d as any) && + !!getRelationFkName(f as any), + ); + } + + if (originalFields.length === 0) { + // Try matching by type reference + // We need this because for relations that don't have @relation, we can only check if the original exists by the field type. + // Yes, in this case it can potentially result in multiple original fields, but we only want to ensure that at least one relation exists. + // In the future, we might implement some logic to detect how many of these types of relations we need and add/remove fields based on this. 
+ originalFields = originalDataModel.fields.filter( + (d) => + f.$type === 'DataField' && + d.$type === 'DataField' && + f.type.reference?.ref && + d.type.reference?.ref && + getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref), + ); + } + + if (originalFields.length > 1) { + // If this is a back-reference relation field (no `fields` attribute), + // silently skip when there are multiple potential matches + const isBackReferenceField = !getRelationFieldsKey(f as any); + if (!isBackReferenceField) { + console.warn( + colors.yellow( + `Found more original fields, need to tweak the search algorithm. ${originalDataModel.name}->[${originalFields.map((of) => of.name).join(', ')}](${f.name})`, + ), + ); + } + return; + } + const originalField = originalFields.at(0); + + // Update existing field if type, optionality, or array flag changed + if (originalField && f.$type === 'DataField' && originalField.$type === 'DataField') { + const newType = f.type; + const oldType = originalField.type; + const fieldUpdates: string[] = []; + + // Check and update builtin type (e.g., String -> Int) + // Skip if old type is an Enum reference and provider doesn't support native enums + const isOldTypeEnumWithoutNativeSupport = + oldType.reference?.ref?.$type === 'Enum' && !provider.isSupportedFeature('NativeEnum'); + if (newType.type && oldType.type !== newType.type && !isOldTypeEnumWithoutNativeSupport) { + fieldUpdates.push(`type: ${oldType.type} -> ${newType.type}`); + (oldType as any).type = newType.type; + } + + // Check and update type reference (e.g., User -> Profile) + if (newType.reference?.ref && oldType.reference?.ref) { + const newRefName = getDbName(newType.reference.ref); + const oldRefName = getDbName(oldType.reference.ref); + if (newRefName !== oldRefName) { + fieldUpdates.push(`reference: ${oldType.reference.$refText} -> ${newType.reference.$refText}`); + // Replace the entire reference object — Langium References + // from parsed documents expose `ref` as a 
getter-only property. + (oldType as any).reference = { + ref: newType.reference.ref, + $refText: newType.reference.$refText, + }; + } + } else if (newType.reference?.ref && !oldType.reference) { + // Changed from builtin to reference type + fieldUpdates.push(`type: ${oldType.type} -> ${newType.reference.$refText}`); + (oldType as any).reference = newType.reference; + (oldType as any).type = undefined; + } else if (!newType.reference && oldType.reference?.ref && newType.type) { + // Changed from reference to builtin type + // Skip if old type is an Enum and provider doesn't support native enums (e.g., SQLite stores enums as strings) + const isEnumWithoutNativeSupport = + oldType.reference.ref.$type === 'Enum' && !provider.isSupportedFeature('NativeEnum'); + if (!isEnumWithoutNativeSupport) { + fieldUpdates.push(`type: ${oldType.reference.$refText} -> ${newType.type}`); + (oldType as any).type = newType.type; + (oldType as any).reference = undefined; + } + } + + // Check and update optionality (e.g., String -> String?) 
+ if (!!newType.optional !== !!oldType.optional) { + fieldUpdates.push(`optional: ${!!oldType.optional} -> ${!!newType.optional}`); + (oldType as any).optional = newType.optional; + } + + // Check and update array flag (e.g., String -> String[]) + if (!!newType.array !== !!oldType.array) { + fieldUpdates.push(`array: ${!!oldType.array} -> ${!!newType.array}`); + (oldType as any).array = newType.array; + } + + if (fieldUpdates.length > 0) { + getModelChanges(originalDataModel.name).updatedFields.push( + colors.yellow(`~ ${originalField.name} (${fieldUpdates.join(', ')})`), + ); + } + + // Update @default attribute arguments if changed + const newDefaultAttr = f.attributes.find((a) => a.decl.$refText === '@default'); + const oldDefaultAttr = originalField.attributes.find((a) => a.decl.$refText === '@default'); + if (newDefaultAttr && oldDefaultAttr) { + // Compare attribute arguments by serializing them (avoid circular refs with $type fallback) + const serializeArgs = (args: any[]) => + args.map((arg) => { + if (arg.value?.$type === 'StringLiteral') return `"${arg.value.value}"`; + if (arg.value?.$type === 'NumberLiteral') return String(arg.value.value); + if (arg.value?.$type === 'BooleanLiteral') return String(arg.value.value); + if (arg.value?.$type === 'InvocationExpr') return arg.value.function?.$refText ?? ''; + if (arg.value?.$type === 'ReferenceExpr') return arg.value.target?.$refText ?? ''; + if (arg.value?.$type === 'ArrayExpr') { + return `[${(arg.value.items ?? []).map((item: any) => { + if (item.$type === 'ReferenceExpr') return item.target?.$refText ?? ''; + return item.$type ?? 'unknown'; + }).join(',')}]`; + } + // Fallback: use $type to avoid circular reference issues + return arg.value?.$type ?? 'unknown'; + }).join(','); + + const newArgsStr = serializeArgs(newDefaultAttr.args ?? []); + const oldArgsStr = serializeArgs(oldDefaultAttr.args ?? 
[]); + + if (newArgsStr !== oldArgsStr) { + // Replace old @default arguments with new ones + (oldDefaultAttr as any).args = newDefaultAttr.args.map((arg) => ({ + ...arg, + $container: oldDefaultAttr, + })); + getModelChanges(originalDataModel.name).updatedAttributes.push( + colors.yellow(`~ @default on ${originalDataModel.name}.${originalField.name}`), + ); + } + } + } + + if (!originalField) { + getModelChanges(originalDataModel.name).addedFields.push(colors.green(`+ ${f.name}`)); + (f as any).$container = originalDataModel; + originalDataModel.fields.push(f as any); + if (f.$type === 'DataField' && f.type.reference?.ref) { + const ref = declarations.find( + (d) => getDbName(d.node as any) === getDbName(f.type.reference!.ref as any), + )?.node as DataModel | undefined; + if (ref) { + // Replace the entire reference object — Langium References + // from parsed documents expose `ref` as a getter-only property. + (f.type as any).reference = { + ref, + $refText: ref.name ?? (f.type.reference as any).$refText, + }; + } + } + return; + } + + // Track deleted attributes (in original but not in new) + originalField.attributes + .filter( + (attr) => + !f.attributes.find((d) => d.decl.$refText === attr.decl.$refText) && + isDatabaseManagedAttribute(attr.decl.$refText), + ) + .forEach((attr) => { + const field = attr.$container; + const index = field.attributes.findIndex((d) => d === attr); + field.attributes.splice(index, 1); + getModelChanges(originalDataModel.name).deletedAttributes.push( + colors.yellow(`- ${attr.decl.$refText} from field: ${originalDataModel.name}.${field.name}`), + ); + }); + + // Track added attributes (in new but not in original) + f.attributes + .filter( + (attr) => + !originalField.attributes.find((d) => d.decl.$refText === attr.decl.$refText) && + isDatabaseManagedAttribute(attr.decl.$refText), + ) + .forEach((attr) => { + // attach the new attribute to the original field + const cloned = { ...attr, $container: originalField } as typeof attr; + 
originalField.attributes.push(cloned); + getModelChanges(originalDataModel.name).addedAttributes.push( + colors.green(`+ ${attr.decl.$refText} to field: ${originalDataModel.name}.${f.name}`), + ); + }); + }); + originalDataModel.fields + .filter((f) => { + // Prioritized matching: exact db name > relation fields key > relation FK name > type reference + const matchByDbName = newDataModel.fields.find((d) => getDbName(d) === getDbName(f)); + if (matchByDbName) return false; + + // Try matching by relation fields key (the `fields` attribute in @relation) + const originalFieldsKey = getRelationFieldsKey(f as any); + if (originalFieldsKey) { + const matchByFieldsKey = newDataModel.fields.find( + (d) => getRelationFieldsKey(d as any) === originalFieldsKey, + ); + if (matchByFieldsKey) return false; + } + + const matchByFkName = newDataModel.fields.find( + (d) => + getRelationFkName(d as any) === getRelationFkName(f as any) && + !!getRelationFkName(d as any) && + !!getRelationFkName(f as any), + ); + if (matchByFkName) return false; + + const matchByTypeRef = newDataModel.fields.find( + (d) => + f.$type === 'DataField' && + d.$type === 'DataField' && + f.type.reference?.ref && + d.type.reference?.ref && + getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref), + ); + return !matchByTypeRef; + }) + .forEach((f) => { + const _model = f.$container; + const index = _model.fields.findIndex((d) => d === f); + _model.fields.splice(index, 1); + getModelChanges(_model.name).deletedFields.push(colors.red(`- ${f.name}`)); + }); + }); + + if (deletedModels.length > 0) { + console.log(colors.bold('\nDeleted Models:')); + deletedModels.forEach((msg) => { + console.log(msg); + }); + } + + if (deletedEnums.length > 0) { + console.log(colors.bold('\nDeleted Enums:')); + deletedEnums.forEach((msg) => { + console.log(msg); + }); + } + + if (addedModels.length > 0) { + console.log(colors.bold('\nAdded Models:')); + addedModels.forEach((msg) => { + console.log(msg); + }); + } + + 
if (addedEnums.length > 0) { + console.log(colors.bold('\nAdded Enums:')); + addedEnums.forEach((msg) => { + console.log(msg); + }); + } + + // Print hierarchical model changes + if (modelChanges.size > 0) { + console.log(colors.bold('\nModel Changes:')); + modelChanges.forEach((changes, modelName) => { + const hasChanges = + changes.addedFields.length > 0 || + changes.deletedFields.length > 0 || + changes.updatedFields.length > 0 || + changes.addedAttributes.length > 0 || + changes.deletedAttributes.length > 0 || + changes.updatedAttributes.length > 0; + + if (hasChanges) { + console.log(colors.cyan(` ${modelName}:`)); + + if (changes.addedFields.length > 0) { + console.log(colors.gray(' Added Fields:')); + changes.addedFields.forEach((msg) => { + console.log(` ${msg}`); + }); + } + + if (changes.deletedFields.length > 0) { + console.log(colors.gray(' Deleted Fields:')); + changes.deletedFields.forEach((msg) => { + console.log(` ${msg}`); + }); + } + + if (changes.updatedFields.length > 0) { + console.log(colors.gray(' Updated Fields:')); + changes.updatedFields.forEach((msg) => { + console.log(` ${msg}`); + }); + } + + if (changes.addedAttributes.length > 0) { + console.log(colors.gray(' Added Attributes:')); + changes.addedAttributes.forEach((msg) => { + console.log(` ${msg}`); + }); + } + + if (changes.deletedAttributes.length > 0) { + console.log(colors.gray(' Deleted Attributes:')); + changes.deletedAttributes.forEach((msg) => { + console.log(` ${msg}`); + }); + } + + if (changes.updatedAttributes.length > 0) { + console.log(colors.gray(' Updated Attributes:')); + changes.updatedAttributes.forEach((msg) => { + console.log(` ${msg}`); + }); + } + } + }); + } + + const generator = new ZModelCodeGenerator({ + quote: options.quote ?? 'single', + indent: options.indent ?? 
4, + }); + + if (options.output) { + if (treatAsFile) { + const zmodelSchema = await formatDocument(generator.generate(newModel)); + console.log(colors.blue(`Writing to ${outPath}`)); + fs.mkdirSync(path.dirname(outPath), { recursive: true }); + fs.writeFileSync(outPath, zmodelSchema); + } else { + // Otherwise treat `--out` as a directory path. Create it if needed. + fs.mkdirSync(outPath!, { recursive: true }); + + // Preserve the directory structure relative to the schema file location (options.schema base). + const baseDir = path.dirname(path.resolve(schemaFile)); + + for (const { + uri, + parseResult: { value: documentModel }, + } of docs) { + const zmodelSchema = await formatDocument(generator.generate(documentModel)); + + // Map input file path -> output file path under `--out` + const relPath = path.relative(baseDir, uri.fsPath); + const targetFile = path.join(outPath!, relPath); + + fs.mkdirSync(path.dirname(targetFile), { recursive: true }); + console.log(colors.blue(`Writing to ${targetFile}`)); + fs.writeFileSync(targetFile, zmodelSchema); + } + } + } else { + for (const { + uri, + parseResult: { value: documentModel }, + } of docs) { + const zmodelSchema = await formatDocument(generator.generate(documentModel)); + console.log(colors.blue(`Writing to ${path.relative(process.cwd(), uri.fsPath).replace(/\\/g, '/')}`)); + fs.writeFileSync(uri.fsPath, zmodelSchema); + } + } + + console.log(colors.green.bold('\nPull completed successfully!')); + } catch (error) { + spinner.fail('Pull failed'); + console.error(error); + throw error; + } +} diff --git a/packages/cli/src/actions/generate.ts b/packages/cli/src/actions/generate.ts index 7ac6db6b2..c014c02ef 100644 --- a/packages/cli/src/actions/generate.ts +++ b/packages/cli/src/actions/generate.ts @@ -12,7 +12,8 @@ import { watch } from 'chokidar'; import ora, { type Ora } from 'ora'; import { CliError } from '../cli-error'; import * as corePlugins from '../plugins'; -import { getOutputPath, getSchemaFile, 
loadSchemaDocument } from './action-utils'; +import { getOutputPath, getSchemaFile, getZenStackPackages, loadSchemaDocument } from './action-utils'; +import semver from 'semver'; type Options = { schema?: string; @@ -27,6 +28,11 @@ type Options = { * CLI action for generating code from schema */ export async function run(options: Options) { + try { + await checkForMismatchedPackages(process.cwd()); + } catch (err) { + console.warn(colors.yellow(`Failed to check for mismatched ZenStack packages: ${err}`)); + } const model = await pureGenerate(options, false); if (options.watch) { @@ -315,3 +321,40 @@ async function loadPluginModule(provider: string, basePath: string) { return undefined; } } + +async function checkForMismatchedPackages(projectPath: string) { + const packages = await getZenStackPackages(projectPath); + if (!packages.length) { + return false; + } + + const versions = new Set(); + for (const { version } of packages) { + if (version) { + versions.add(version); + } + } + + if (versions.size > 1) { + const message = + 'WARNING: Multiple versions of ZenStack packages detected.\n\tThis will probably cause issues and break your types.'; + const slashes = '/'.repeat(73); + const latestVersion = semver.sort(Array.from(versions)).reverse()[0]!; + + console.warn(colors.yellow(`${slashes}\n\n\t${message}\n`)); + for (const { pkg, version } of packages) { + if (!version) continue; + + if (version === latestVersion) { + console.log(`\t${pkg.padEnd(32)}\t${colors.green(version)}`); + } else { + console.log(`\t${pkg.padEnd(32)}\t${colors.yellow(version)}`); + } + } + console.warn(`\n${colors.yellow(slashes)}`); + + return true; + } + + return false; +} diff --git a/packages/cli/src/actions/info.ts b/packages/cli/src/actions/info.ts index bbea51ebb..26e42422f 100644 --- a/packages/cli/src/actions/info.ts +++ b/packages/cli/src/actions/info.ts @@ -1,12 +1,12 @@ import colors from 'colors'; -import path from 'node:path'; +import { getZenStackPackages } from 
'./action-utils'; /** * CLI action for getting information about installed ZenStack packages */ export async function run(projectPath: string) { const packages = await getZenStackPackages(projectPath); - if (!packages) { + if (!packages.length) { console.error('Unable to locate package.json. Are you in a valid project directory?'); return; } @@ -24,48 +24,3 @@ export async function run(projectPath: string) { console.warn(colors.yellow('WARNING: Multiple versions of Zenstack packages detected. This may cause issues.')); } } - -async function getZenStackPackages(projectPath: string): Promise> { - let pkgJson: { - dependencies: Record; - devDependencies: Record; - }; - const resolvedPath = path.resolve(projectPath); - try { - pkgJson = ( - await import(path.join(resolvedPath, 'package.json'), { - with: { type: 'json' }, - }) - ).default; - } catch { - return []; - } - - const packages = Array.from( - new Set( - [...Object.keys(pkgJson.dependencies ?? {}), ...Object.keys(pkgJson.devDependencies ?? {})].filter( - (p) => p.startsWith('@zenstackhq/') || p === 'zenstack', - ), - ), - ).sort(); - - const result = await Promise.all( - packages.map(async (pkg) => { - try { - const depPkgJson = ( - await import(`${pkg}/package.json`, { - with: { type: 'json' }, - }) - ).default; - if (depPkgJson.private) { - return undefined; - } - return { pkg, version: depPkgJson.version as string }; - } catch { - return { pkg, version: undefined }; - } - }), - ); - - return result.filter((p) => !!p); -} diff --git a/packages/cli/src/actions/pull/casing.ts b/packages/cli/src/actions/pull/casing.ts new file mode 100644 index 000000000..5e0846bc9 --- /dev/null +++ b/packages/cli/src/actions/pull/casing.ts @@ -0,0 +1,43 @@ +export function resolveNameCasing(casing: 'pascal' | 'camel' | 'snake' | 'none', originalName: string) { + let name = originalName; + const fieldPrefix = /[0-9]/g.test(name.charAt(0)) ? 
'_' : ''; + + switch (casing) { + case 'pascal': + name = toPascalCase(originalName); + break; + case 'camel': + name = toCamelCase(originalName); + break; + case 'snake': + name = toSnakeCase(originalName); + break; + } + + return { + modified: name !== originalName || fieldPrefix !== '', + name: `${fieldPrefix}${name}`, + }; +} + +function isAllUpperCase(str: string): boolean { + return str === str.toUpperCase(); +} + +export function toPascalCase(str: string): string { + if (isAllUpperCase(str)) return str; + return str.replace(/[_\- ]+(\w)/g, (_, c) => c.toUpperCase()).replace(/^\w/, (c) => c.toUpperCase()); +} + +export function toCamelCase(str: string): string { + if (isAllUpperCase(str)) return str; + return str.replace(/[_\- ]+(\w)/g, (_, c) => c.toUpperCase()).replace(/^\w/, (c) => c.toLowerCase()); +} + +export function toSnakeCase(str: string): string { + if (isAllUpperCase(str)) return str; + return str + .replace(/[- ]+/g, '_') + .replace(/([a-z0-9])([A-Z])/g, '$1_$2') + .toLowerCase(); +} diff --git a/packages/cli/src/actions/pull/index.ts b/packages/cli/src/actions/pull/index.ts new file mode 100644 index 000000000..998852cb6 --- /dev/null +++ b/packages/cli/src/actions/pull/index.ts @@ -0,0 +1,678 @@ +import type { ZModelServices } from '@zenstackhq/language'; +import colors from 'colors'; +import { + isEnum, + type DataField, + type DataModel, + type Enum, + type Model, +} from '@zenstackhq/language/ast'; +import { + DataFieldAttributeFactory, + DataFieldFactory, + DataModelFactory, + EnumFactory, +} from '@zenstackhq/language/factory'; +import { AstUtils, type Reference, type AstNode, type CstNode } from 'langium'; +import { lowerCaseFirst } from '@zenstackhq/common-helpers'; +import type { PullOptions } from '../db'; +import type { Cascade, IntrospectedEnum, IntrospectedTable, IntrospectionProvider } from './provider'; +import { getAttributeRef, getDbName, getEnumRef } from './utils'; +import { resolveNameCasing } from './casing'; +import { 
CliError } from '../../cli-error'; + +export function syncEnums({ + dbEnums, + model, + oldModel, + provider, + options, + services, + defaultSchema, +}: { + dbEnums: IntrospectedEnum[]; + model: Model; + oldModel: Model; + provider: IntrospectionProvider; + services: ZModelServices; + options: PullOptions; + defaultSchema: string; +}) { + if (provider.isSupportedFeature('NativeEnum')) { + for (const dbEnum of dbEnums) { + const { modified, name } = resolveNameCasing(options.modelCasing, dbEnum.enum_type); + if (modified) console.log(colors.gray(`Mapping enum ${dbEnum.enum_type} to ${name}`)); + const factory = new EnumFactory().setName(name); + if (modified || options.alwaysMap) + factory.addAttribute((builder) => + builder + .setDecl(getAttributeRef('@@map', services)) + .addArg((argBuilder) => argBuilder.StringLiteral.setValue(dbEnum.enum_type)), + ); + + dbEnum.values.forEach((v) => { + const { name, modified } = resolveNameCasing(options.fieldCasing, v); + factory.addField((builder) => { + builder.setName(name); + if (modified || options.alwaysMap) + builder.addAttribute((builder) => + builder + .setDecl(getAttributeRef('@map', services)) + .addArg((argBuilder) => argBuilder.StringLiteral.setValue(v)), + ); + + return builder; + }); + }); + + if (dbEnum.schema_name && dbEnum.schema_name !== '' && dbEnum.schema_name !== defaultSchema) { + factory.addAttribute((b) => + b + .setDecl(getAttributeRef('@@schema', services)) + .addArg((a) => a.StringLiteral.setValue(dbEnum.schema_name)), + ); + } + + model.declarations.push(factory.get({ $container: model })); + } + } else { + // For providers that don't support native enums (e.g., SQLite), carry over + // enum declarations from the existing schema as-is by deep-cloning the AST nodes. + // A dummy buildReference is used since we don't need cross-reference resolution. 
+ const dummyBuildReference = (_node: AstNode, _property: string, _refNode: CstNode | undefined, refText: string): Reference => + ({ $refText: refText }) as Reference; + + oldModel.declarations + .filter((d) => isEnum(d)) + .forEach((d) => { + const copy = AstUtils.copyAstNode(d, dummyBuildReference); + (copy as { $container: unknown }).$container = model; + model.declarations.push(copy); + }); + } +} + +export type Relation = { + schema: string; + table: string; + columns: string[]; + type: 'one' | 'many'; + fk_name: string; + foreign_key_on_update: Cascade; + foreign_key_on_delete: Cascade; + nullable: boolean; + references: { + schema: string | null; + table: string | null; + columns: (string | null)[]; + type: 'one' | 'many'; + }; +}; + +export function syncTable({ + model, + provider, + table, + services, + options, + defaultSchema, +}: { + table: IntrospectedTable; + model: Model; + oldModel: Model; + provider: IntrospectionProvider; + services: ZModelServices; + options: PullOptions; + defaultSchema: string; +}) { + const idAttribute = getAttributeRef('@id', services); + const modelIdAttribute = getAttributeRef('@@id', services); + const uniqueAttribute = getAttributeRef('@unique', services); + const modelUniqueAttribute = getAttributeRef('@@unique', services); + const fieldMapAttribute = getAttributeRef('@map', services); + const tableMapAttribute = getAttributeRef('@@map', services); + const modelindexAttribute = getAttributeRef('@@index', services); + + const relations: Relation[] = []; + const { name, modified } = resolveNameCasing(options.modelCasing, table.name); + const multiPk = table.columns.filter((c) => c.pk).length > 1; + + const modelFactory = new DataModelFactory().setName(name).setIsView(table.type === 'view'); + modelFactory.setContainer(model); + + if (modified || options.alwaysMap) { + modelFactory.addAttribute((builder) => + builder.setDecl(tableMapAttribute).addArg((argBuilder) => argBuilder.StringLiteral.setValue(table.name)), + ); + } + 
// Group FK columns by constraint name to handle composite foreign keys. + // Each FK constraint (identified by fk_name) may span multiple columns. + const fkGroups = new Map(); + table.columns.forEach((column) => { + if (column.foreign_key_table && column.foreign_key_name) { + const group = fkGroups.get(column.foreign_key_name) ?? []; + group.push(column); + fkGroups.set(column.foreign_key_name, group); + } + }); + + for (const [fkName, fkColumns] of fkGroups) { + const firstCol = fkColumns[0]!; + // For single-column FKs, check if the column is the table's single-column PK (one-to-one) + const isSingleColumnPk = fkColumns.length === 1 && !multiPk && firstCol.pk; + // A single-column FK with unique constraint means one-to-one on the opposite side + const isUniqueRelation = (fkColumns.length === 1 && firstCol.unique) || isSingleColumnPk; + relations.push({ + schema: table.schema, + table: table.name, + columns: fkColumns.map((c) => c.name), + type: 'one', + fk_name: fkName, + foreign_key_on_delete: firstCol.foreign_key_on_delete, + foreign_key_on_update: firstCol.foreign_key_on_update, + nullable: firstCol.nullable, + references: { + schema: firstCol.foreign_key_schema, + table: firstCol.foreign_key_table, + columns: fkColumns.map((c) => c.foreign_key_column), + type: isUniqueRelation ? 'one' : 'many', + }, + }); + } + + table.columns.forEach((column) => { + + const { name, modified } = resolveNameCasing(options.fieldCasing, column.name); + + const builtinType = provider.getBuiltinType(column.datatype); + + modelFactory.addField((builder) => { + builder.setName(name); + builder.setType((typeBuilder) => { + typeBuilder.setArray(builtinType.isArray); + // Array fields cannot be optional (Prisma/ZenStack limitation) + typeBuilder.setOptional(builtinType.isArray ? false : column.nullable); + + if (column.computed) { + // Generated/computed columns (e.g., GENERATED ALWAYS AS ... 
STORED/VIRTUAL) + // are read-only and must be rendered as Unsupported("full type definition"). + // The datatype contains the full DDL type definition including the expression. + typeBuilder.setUnsupported((unsupportedBuilder) => + unsupportedBuilder.setValue((lt) => lt.StringLiteral.setValue(column.datatype)), + ); + } else if (column.datatype === 'enum') { + const ref = model.declarations.find((d) => isEnum(d) && getDbName(d) === column.datatype_name) as + | Enum + | undefined; + + if (!ref) { + throw new CliError(`Enum ${column.datatype_name} not found`); + } + typeBuilder.setReference(ref); + } else { + if (builtinType.type !== 'Unsupported') { + typeBuilder.setType(builtinType.type); + } else { + typeBuilder.setUnsupported((unsupportedBuilder) => + unsupportedBuilder.setValue((lt) => lt.StringLiteral.setValue(column.datatype)), + ); + } + } + + return typeBuilder; + }); + + if (column.pk && !multiPk) { + builder.addAttribute((b) => b.setDecl(idAttribute)); + } + + // Add field-type-based attributes (e.g., @updatedAt for DateTime fields, @db.* attributes) + const fieldAttrs = provider.getFieldAttributes({ + fieldName: column.name, + fieldType: builtinType.type, + datatype: column.datatype, + length: column.length, + precision: column.precision, + services, + }); + fieldAttrs.forEach(builder.addAttribute.bind(builder)); + + if (column.default && !column.computed) { + const defaultExprBuilder = provider.getDefaultValue({ + fieldType: builtinType.type, + datatype: column.datatype, + datatype_name: column.datatype_name, + defaultValue: column.default, + services, + enums: model.declarations.filter((d) => d.$type === 'Enum') as Enum[], + }); + if (defaultExprBuilder) { + const defaultAttr = new DataFieldAttributeFactory() + .setDecl(getAttributeRef('@default', services)) + .addArg(defaultExprBuilder); + builder.addAttribute(defaultAttr); + } + } + + if (column.unique && !column.pk) { + builder.addAttribute((b) => { + b.setDecl(uniqueAttribute); + // Only add map if 
the unique constraint name differs from default patterns + // Default patterns: TableName_columnName_key (Prisma) or just columnName (MySQL) + const isDefaultName = !column.unique_name + || column.unique_name === `${table.name}_${column.name}_key` + || column.unique_name === column.name; + if (!isDefaultName) { + b.addArg((ab) => ab.StringLiteral.setValue(column.unique_name!), 'map'); + } + + return b; + }); + } + if (modified || options.alwaysMap) { + builder.addAttribute((ab) => + ab.setDecl(fieldMapAttribute).addArg((ab) => ab.StringLiteral.setValue(column.name)), + ); + } + + return builder; + }); + }); + + const pkColumns = table.columns.filter((c) => c.pk).map((c) => c.name); + if (multiPk) { + modelFactory.addAttribute((builder) => + builder.setDecl(modelIdAttribute).addArg((argBuilder) => { + const arrayExpr = argBuilder.ArrayExpr; + pkColumns.forEach((c) => { + const ref = modelFactory.node.fields.find((f) => getDbName(f) === c); + if (!ref) { + throw new CliError(`Field ${c} not found`); + } + arrayExpr.addItem((itemBuilder) => itemBuilder.ReferenceExpr.setTarget(ref)); + }); + return arrayExpr; + }), + ); + } + + const hasUniqueConstraint = + table.columns.some((c) => c.unique || c.pk) || + table.indexes.some((i) => i.unique); + if (!hasUniqueConstraint) { + modelFactory.addAttribute((a) => a.setDecl(getAttributeRef('@@ignore', services))); + modelFactory.addComment( + '/// The underlying table does not contain a valid unique identifier and can therefore currently not be handled by Zenstack Client.', + ); + } + + // Sort indexes: unique indexes first, then other indexes + const sortedIndexes = table.indexes.reverse().sort((a, b) => { + if (a.unique && !b.unique) return -1; + if (!a.unique && b.unique) return 1; + return 0; + }); + + sortedIndexes.forEach((index) => { + if (index.predicate) { + //These constraints are not supported by Zenstack, because Zenstack currently does not fully support check constraints. 
Read more: https://pris.ly/d/check-constraints + console.warn( + colors.yellow( + `These constraints are not supported by Zenstack. Read more: https://pris.ly/d/check-constraints\n- Model: "${table.name}", constraint: "${index.name}"`, + ), + ); + return; + } + if (index.columns.find((c) => c.expression)) { + console.warn( + colors.yellow( + `These constraints are not supported by Zenstack. Read more: https://pris.ly/d/check-constraints\n- Model: "${table.name}", constraint: "${index.name}"`, + ), + ); + return; + } + + // Skip PRIMARY key index (handled via @id or @@id) + if (index.primary) { + return; + } + + // Skip single-column indexes that are already handled by @id or @unique on the field + if (index.columns.length === 1 && (index.columns.find((c) => pkColumns.includes(c.name)) || index.unique)) { + return; + } + + modelFactory.addAttribute((builder) => + { + const attr = builder + .setDecl(index.unique ? modelUniqueAttribute : modelindexAttribute) + .addArg((argBuilder) => { + const arrayExpr = argBuilder.ArrayExpr; + index.columns.forEach((c) => { + const ref = modelFactory.node.fields.find((f) => getDbName(f) === c.name); + if (!ref) { + throw new CliError(`Column ${c.name} not found in model ${table.name}`); + } + arrayExpr.addItem((itemBuilder) => { + const refExpr = itemBuilder.ReferenceExpr.setTarget(ref); + if (c.order && c.order !== 'ASC') + refExpr.addArg((ab) => ab.StringLiteral.setValue('DESC'), 'sort'); + + return refExpr; + }); + }); + return arrayExpr; + }); + + const suffix = index.unique ? 
'_key' : '_idx'; + + if(index.name !== `${table.name}_${index.columns.map(c => c.name).join('_')}${suffix}`){ + attr.addArg((argBuilder) => argBuilder.StringLiteral.setValue(index.name), 'map'); + } + + return attr + } + + ); + }); + if (table.schema && table.schema !== '' && table.schema !== defaultSchema) { + modelFactory.addAttribute((b) => + b.setDecl(getAttributeRef('@@schema', services)).addArg((a) => a.StringLiteral.setValue(table.schema)), + ); + } + + model.declarations.push(modelFactory.node); + return relations; +} + +export function syncRelation({ + model, + relation, + services, + options, + selfRelation, + similarRelations, +}: { + model: Model; + relation: Relation; + services: ZModelServices; + options: PullOptions; + //self included + similarRelations: number; + selfRelation: boolean; +}) { + const idAttribute = getAttributeRef('@id', services); + const uniqueAttribute = getAttributeRef('@unique', services); + const relationAttribute = getAttributeRef('@relation', services); + const fieldMapAttribute = getAttributeRef('@map', services); + const tableMapAttribute = getAttributeRef('@@map', services); + + const includeRelationName = selfRelation || similarRelations > 0; + + if (!idAttribute || !uniqueAttribute || !relationAttribute || !fieldMapAttribute || !tableMapAttribute) { + throw new CliError('Cannot find required attributes in the model.'); + } + + const sourceModel = model.declarations.find((d) => d.$type === 'DataModel' && getDbName(d) === relation.table) as + | DataModel + | undefined; + if (!sourceModel) return; + + // Resolve all source and target fields for the relation (supports composite FKs) + const sourceFields: { field: DataField; index: number }[] = []; + for (const colName of relation.columns) { + const idx = sourceModel.fields.findIndex((f) => getDbName(f) === colName); + const field = sourceModel.fields[idx] as DataField | undefined; + if (!field) return; + sourceFields.push({ field, index: idx }); + } + + const targetModel = 
model.declarations.find( + (d) => d.$type === 'DataModel' && getDbName(d) === relation.references.table, + ) as DataModel | undefined; + if (!targetModel) return; + + const targetFields: DataField[] = []; + for (const colName of relation.references.columns) { + const field = targetModel.fields.find((f) => getDbName(f) === colName); + if (!field) return; + targetFields.push(field); + } + + // Use the first source field for naming heuristics + const firstSourceField = sourceFields[0]!.field; + const firstSourceFieldId = sourceFields[0]!.index; + const firstColumn = relation.columns[0]!; + + const fieldPrefix = /[0-9]/g.test(sourceModel.name.charAt(0)) ? '_' : ''; + + const relationName = `${relation.table}${similarRelations > 0 ? `_${firstColumn}` : ''}To${relation.references.table}`; + + // Derive a relation field name from the FK scalar field: if the field ends with "Id", + // strip the suffix and use the remainder (e.g., "authorId" -> "author"). + const sourceNameFromReference = firstSourceField.name.toLowerCase().endsWith('id') ? `${resolveNameCasing(options.fieldCasing, firstSourceField.name.slice(0, -2)).name}${relation.type === 'many'? 's' : ''}` : undefined; + + // Check if the derived name would clash with an existing field + const sourceFieldFromReference = sourceModel.fields.find((f) => f.name === sourceNameFromReference); + + // Determine the relation field name: + // - For ambiguous relations (multiple FKs to the same table), include the source column for disambiguation. + // - Otherwise, prefer the name derived from the FK field (if no clash), falling back to the target model name. + let { name: sourceFieldName } = resolveNameCasing( + options.fieldCasing, + similarRelations > 0 + ? `${fieldPrefix}${lowerCaseFirst(sourceModel.name)}_${firstColumn}` + : `${(!sourceFieldFromReference? sourceNameFromReference : undefined) || lowerCaseFirst(resolveNameCasing(options.fieldCasing, targetModel.name).name)}${relation.type === 'many'? 
's' : ''}`, + ); + + if (sourceModel.fields.find((f) => f.name === sourceFieldName)) { + sourceFieldName = `${sourceFieldName}To${lowerCaseFirst(targetModel.name)}_${relation.references.columns[0]}`; + } + + const sourceFieldFactory = new DataFieldFactory() + .setContainer(sourceModel) + .setName(sourceFieldName) + .setType((tb) => + tb + .setOptional(relation.nullable) + .setArray(relation.type === 'many') + .setReference(targetModel), + ); + sourceFieldFactory.addAttribute((ab) => { + ab.setDecl(relationAttribute); + if (includeRelationName) ab.addArg((ab) => ab.StringLiteral.setValue(relationName)); + + // Build fields array (all source FK columns) + ab.addArg((ab) => { + const arrayExpr = ab.ArrayExpr; + for (const { field } of sourceFields) { + arrayExpr.addItem((aeb) => aeb.ReferenceExpr.setTarget(field)); + } + return arrayExpr; + }, 'fields'); + + // Build references array (all target columns) + ab.addArg((ab) => { + const arrayExpr = ab.ArrayExpr; + for (const field of targetFields) { + arrayExpr.addItem((aeb) => aeb.ReferenceExpr.setTarget(field)); + } + return arrayExpr; + }, 'references'); + + // Prisma defaults: onDelete is SetNull for optional, Restrict for mandatory + const onDeleteDefault = relation.nullable ? 
'SET NULL' : 'RESTRICT'; + if (relation.foreign_key_on_delete && relation.foreign_key_on_delete !== onDeleteDefault) { + const enumRef = getEnumRef('ReferentialAction', services); + if (!enumRef) throw new CliError('ReferentialAction enum not found'); + const enumFieldRef = enumRef.fields.find( + (f) => f.name.toLowerCase() === relation.foreign_key_on_delete!.replace(/ /g, '').toLowerCase(), + ); + if (!enumFieldRef) throw new CliError(`ReferentialAction ${relation.foreign_key_on_delete} not found`); + ab.addArg((a) => a.ReferenceExpr.setTarget(enumFieldRef), 'onDelete'); + } + + // Prisma default: onUpdate is Cascade + if (relation.foreign_key_on_update && relation.foreign_key_on_update !== 'CASCADE') { + const enumRef = getEnumRef('ReferentialAction', services); + if (!enumRef) throw new CliError('ReferentialAction enum not found'); + const enumFieldRef = enumRef.fields.find( + (f) => f.name.toLowerCase() === relation.foreign_key_on_update!.replace(/ /g, '').toLowerCase(), + ); + if (!enumFieldRef) throw new CliError(`ReferentialAction ${relation.foreign_key_on_update} not found`); + ab.addArg((a) => a.ReferenceExpr.setTarget(enumFieldRef), 'onUpdate'); + } + + // Check if the FK constraint name differs from the default pattern + const defaultFkName = `${relation.table}_${relation.columns.join('_')}_fkey`; + if (relation.fk_name && relation.fk_name !== defaultFkName) ab.addArg((ab) => ab.StringLiteral.setValue(relation.fk_name), 'map'); + + return ab; + }); + + sourceModel.fields.splice(firstSourceFieldId, 0, sourceFieldFactory.node); // Insert the relation field before the first FK scalar field + + const oppositeFieldPrefix = /[0-9]/g.test(targetModel.name.charAt(0)) ? '_' : ''; + let { name: oppositeFieldName } = resolveNameCasing( + options.fieldCasing, + similarRelations > 0 + ? 
`${oppositeFieldPrefix}${lowerCaseFirst(sourceModel.name)}_${firstColumn}` + : `${lowerCaseFirst(resolveNameCasing(options.fieldCasing, sourceModel.name).name)}${relation.references.type === 'many'? 's' : ''}`, + ); + + if (targetModel.fields.find((f) => f.name === oppositeFieldName)) { + ({ name: oppositeFieldName } = resolveNameCasing( + options.fieldCasing, + `${lowerCaseFirst(sourceModel.name)}_${firstColumn}To${relation.references.table}_${relation.references.columns[0]}`, + )); + } + + const targetFieldFactory = new DataFieldFactory() + .setContainer(targetModel) + .setName(oppositeFieldName) + .setType((tb) => + tb + .setOptional(relation.references.type === 'one') + .setArray(relation.references.type === 'many') + .setReference(sourceModel), + ); + if (includeRelationName) + targetFieldFactory.addAttribute((ab) => + ab.setDecl(relationAttribute).addArg((ab) => ab.StringLiteral.setValue(relationName)), + ); + + targetModel.fields.push(targetFieldFactory.node); +} + +/** + * Consolidates per-column enums back to shared enums when possible. + * + * MySQL doesn't have named enum types — each column gets a synthetic enum + * (e.g., `UserStatus`, `GroupStatus`). When the original schema used a shared + * enum (e.g., `Status`) across multiple fields, this function detects the + * mapping via field references and consolidates the synthetic enums back into + * the original shared enum so the merge phase can match them correctly. 
+ */ +export function consolidateEnums({ + newModel, + oldModel, +}: { + newModel: Model; + oldModel: Model; +}) { + const newEnums = newModel.declarations.filter((d) => isEnum(d)) as Enum[]; + const newDataModels = newModel.declarations.filter((d) => d.$type === 'DataModel') as DataModel[]; + const oldDataModels = oldModel.declarations.filter((d) => d.$type === 'DataModel') as DataModel[]; + + // For each new enum, find which old enum it corresponds to (via field references) + const enumMapping = new Map(); // newEnum -> oldEnum + + for (const newEnum of newEnums) { + for (const newDM of newDataModels) { + for (const field of newDM.fields) { + if (field.$type !== 'DataField' || field.type.reference?.ref !== newEnum) continue; + + // Find matching model in old model by db name + const oldDM = oldDataModels.find((d) => getDbName(d) === getDbName(newDM)); + if (!oldDM) continue; + + // Find matching field in old model by db name + const oldField = oldDM.fields.find((f) => getDbName(f) === getDbName(field)); + if (!oldField || oldField.$type !== 'DataField' || !oldField.type.reference?.ref) continue; + + const oldEnum = oldField.type.reference.ref; + if (!isEnum(oldEnum)) continue; + + enumMapping.set(newEnum, oldEnum as Enum); + break; + } + if (enumMapping.has(newEnum)) break; + } + } + + // Group by old enum: oldEnum -> [newEnum1, newEnum2, ...] 
+ const reverseMapping = new Map(); + for (const [newEnum, oldEnum] of enumMapping) { + if (!reverseMapping.has(oldEnum)) { + reverseMapping.set(oldEnum, []); + } + reverseMapping.get(oldEnum)!.push(newEnum); + } + + // Consolidate: when new enums map to the same old enum with matching values + for (const [oldEnum, newEnumsGroup] of reverseMapping) { + const keepEnum = newEnumsGroup[0]!; + + // Skip if already correct (single enum with matching name) + if (newEnumsGroup.length === 1 && keepEnum.name === oldEnum.name) continue; + + // Check that all new enums have the same values as the old enum + const oldValues = new Set(oldEnum.fields.map((f) => getDbName(f))); + const allMatch = newEnumsGroup.every((ne) => { + const newValues = new Set(ne.fields.map((f) => getDbName(f))); + return oldValues.size === newValues.size && [...oldValues].every((v) => newValues.has(v)); + }); + + if (!allMatch) continue; + + // Rename the kept enum to match the old shared name + keepEnum.name = oldEnum.name; + + // Replace keepEnum's attributes with those from the old enum so that + // any synthetic @@map added by syncEnums is removed and getDbName(keepEnum) + // reflects the consolidated name rather than the stale per-column name. + // Shallow-copy and re-parent so AST $container pointers reference keepEnum. 
+ keepEnum.attributes = oldEnum.attributes.map((attr) => { + const copy = { ...attr, $container: keepEnum }; + return copy; + }); + + // Remove duplicate enums from newModel + for (let i = 1; i < newEnumsGroup.length; i++) { + const idx = newModel.declarations.indexOf(newEnumsGroup[i]!); + if (idx >= 0) { + newModel.declarations.splice(idx, 1); + } + } + + // Update all field references in newModel to point to the kept enum + for (const newDM of newDataModels) { + for (const field of newDM.fields) { + if (field.$type !== 'DataField') continue; + const ref = field.type.reference?.ref; + if (ref && newEnumsGroup.includes(ref as Enum)) { + (field.type as any).reference = { + ref: keepEnum, + $refText: keepEnum.name, + }; + } + } + } + + console.log( + colors.gray( + `Consolidated enum${newEnumsGroup.length > 1 ? 's' : ''} ${newEnumsGroup.map((e) => e.name).join(', ')} → ${oldEnum.name}`, + ), + ); + } +} diff --git a/packages/cli/src/actions/pull/provider/index.ts b/packages/cli/src/actions/pull/provider/index.ts new file mode 100644 index 000000000..7c93746d4 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/index.ts @@ -0,0 +1,13 @@ +import type { DataSourceProviderType } from '@zenstackhq/schema'; +export * from './provider'; + +import { mysql } from './mysql'; +import { postgresql } from './postgresql'; +import type { IntrospectionProvider } from './provider'; +import { sqlite } from './sqlite'; + +export const providers: Record = { + mysql, + postgresql, + sqlite, +}; diff --git a/packages/cli/src/actions/pull/provider/mysql.ts b/packages/cli/src/actions/pull/provider/mysql.ts new file mode 100644 index 000000000..1c8124435 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/mysql.ts @@ -0,0 +1,601 @@ +import type { Attribute, BuiltinType } from '@zenstackhq/language/ast'; +import { DataFieldAttributeFactory } from '@zenstackhq/language/factory'; +import { getAttributeRef, getDbName, getFunctionRef, normalizeDecimalDefault, normalizeFloatDefault } 
from '../utils'; +import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider'; +import { CliError } from '../../../cli-error'; +import { resolveNameCasing } from '../casing'; + +// Note: We dynamically import mysql2 inside the async function to avoid +// requiring it at module load time for environments that don't use MySQL. + +function normalizeGenerationExpression(typeDef: string): string { + // MySQL may include character set introducers in generation expressions, e.g. `_utf8mb4' '`. + // Strip them to produce a stable, cleaner expression for `Unsupported("...")`. + // MySQL commonly returns generation expressions with SQL-style quote escaping (e.g. `\\'`), + // which would become an invalid ZModel string after the code generator escapes quotes again. + // Normalize it to raw quotes, letting the ZModel code generator re-escape appropriately. + return ( + typeDef + // Remove character set introducers, with or without escaped quotes. + .replace(/_([0-9A-Za-z_]+)\\?'/g, "'") + // Unescape SQL-style escaped single quotes in the expression. 
+ .replace(/\\'/g, "'") + ); +} + +export const mysql: IntrospectionProvider = { + isSupportedFeature(feature) { + switch (feature) { + case 'NativeEnum': + return true; + case 'Schema': + default: + return false; + } + }, + getBuiltinType(type) { + const t = (type || '').toLowerCase().trim(); + + // MySQL doesn't have native array types + const isArray = false; + + switch (t) { + // integers + case 'tinyint': + case 'smallint': + case 'mediumint': + case 'int': + case 'integer': + return { type: 'Int', isArray }; + case 'bigint': + return { type: 'BigInt', isArray }; + + // decimals and floats + case 'decimal': + case 'numeric': + return { type: 'Decimal', isArray }; + case 'float': + case 'double': + case 'real': + return { type: 'Float', isArray }; + + // boolean (MySQL uses TINYINT(1) for boolean) + case 'boolean': + case 'bool': + return { type: 'Boolean', isArray }; + + // strings + case 'char': + case 'varchar': + case 'tinytext': + case 'text': + case 'mediumtext': + case 'longtext': + return { type: 'String', isArray }; + + // dates/times + case 'date': + case 'time': + case 'datetime': + case 'timestamp': + case 'year': + return { type: 'DateTime', isArray }; + + // binary + case 'binary': + case 'varbinary': + case 'tinyblob': + case 'blob': + case 'mediumblob': + case 'longblob': + return { type: 'Bytes', isArray }; + + // json + case 'json': + return { type: 'Json', isArray }; + + default: + // Handle ENUM type - MySQL returns enum values like "enum('val1','val2')" + if (t.startsWith('enum(')) { + return { type: 'String', isArray }; + } + // Handle SET type + if (t.startsWith('set(')) { + return { type: 'String', isArray }; + } + return { type: 'Unsupported' as const, isArray }; + } + }, + getDefaultDatabaseType(type: BuiltinType) { + switch (type) { + case 'String': + return { type: 'varchar', precision: 191 }; + case 'Boolean': + // Boolean maps to 'boolean' (our synthetic type from tinyint(1)) + // No precision needed since we handle the mapping in 
the query + return { type: 'boolean' }; + case 'Int': + return { type: 'int' }; + case 'BigInt': + return { type: 'bigint' }; + case 'Float': + return { type: 'double' }; + case 'Decimal': + return { type: 'decimal', precision: 65 }; + case 'DateTime': + return { type: 'datetime', precision: 3 }; + case 'Json': + return { type: 'json' }; + case 'Bytes': + return { type: 'longblob' }; + } + }, + async introspect(connectionString: string, options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' }): Promise { + const mysql = await import('mysql2/promise'); + const connection = await mysql.createConnection(connectionString); + + try { + // Extract database name from connection string + const url = new URL(connectionString); + const databaseName = url.pathname.replace('/', ''); + + if (!databaseName) { + throw new CliError('Database name not found in connection string'); + } + + // Introspect tables + const [tableRows] = (await connection.execute(getTableIntrospectionQuery(), [databaseName])) as [ + IntrospectedTable[], + unknown, + ]; + const tables: IntrospectedTable[] = []; + + for (const row of tableRows) { + const columns = typeof row.columns === 'string' ? JSON.parse(row.columns) : row.columns; + const indexes = typeof row.indexes === 'string' ? JSON.parse(row.indexes) : row.indexes; + + // Sort columns by ordinal_position to preserve database column order + const sortedColumns = (columns || []) + .sort( + (a: { ordinal_position?: number }, b: { ordinal_position?: number }) => + (a.ordinal_position ?? 0) - (b.ordinal_position ?? 0) + ) + .map((col: any) => { + // MySQL enum datatype_name is synthetic (TableName_ColumnName). + // Apply model casing so it matches the cased enum_type. + if (col.datatype === 'enum' && col.datatype_name) { + return { ...col, datatype_name: resolveNameCasing(options.modelCasing, col.datatype_name).name }; + } + // Normalize generated column expressions for stable output. 
+ if (col.computed && typeof col.datatype === 'string') { + return { ...col, datatype: normalizeGenerationExpression(col.datatype) }; + } + return col; + }); + + // Filter out auto-generated FK indexes (MySQL creates these automatically) + // Pattern: {Table}_{column}_fkey for single-column FK indexes + const filteredIndexes = (indexes || []).filter( + (idx: { name: string; columns: { name: string }[] }) => + !(idx.columns.length === 1 && idx.name === `${row.name}_${idx.columns[0]?.name}_fkey`) + ); + + tables.push({ + schema: '', // MySQL doesn't support multi-schema + name: row.name, + type: row.type as 'table' | 'view', + definition: row.definition, + columns: sortedColumns, + indexes: filteredIndexes, + }); + } + + // Introspect enums (MySQL stores enum values in column definitions) + const [enumRows] = (await connection.execute(getEnumIntrospectionQuery(), [databaseName])) as [ + { table_name: string; column_name: string; column_type: string }[], + unknown, + ]; + + const enums: IntrospectedEnum[] = enumRows.map((row) => { + // Parse enum values from column_type like "enum('val1','val2','val3')" + const values = parseEnumValues(row.column_type); + // MySQL doesn't have standalone enum types; the name is entirely + // synthetic (TableName_ColumnName). Apply model casing here so it + // arrives already cased — there is no raw DB name to @@map back to. 
+ const syntheticName = `${row.table_name}_${row.column_name}`; + const { name } = resolveNameCasing(options.modelCasing, syntheticName); + return { + schema_name: '', // MySQL doesn't support multi-schema + enum_type: name, + values, + }; + }); + + return { tables, enums }; + } finally { + await connection.end(); + } + }, + getDefaultValue({ defaultValue, fieldType, datatype, datatype_name, services, enums }) { + const val = defaultValue.trim(); + + // Handle NULL early + if (val.toUpperCase() === 'NULL') { + return null; + } + + // Handle enum defaults + if (datatype === 'enum' && datatype_name) { + const enumDef = enums.find((e) => getDbName(e) === datatype_name); + if (enumDef) { + // Strip quotes from the value (MySQL returns 'value') + const enumValue = val.startsWith("'") && val.endsWith("'") ? val.slice(1, -1) : val; + const enumField = enumDef.fields.find((f) => getDbName(f) === enumValue); + if (enumField) { + return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + } + } + + switch (fieldType) { + case 'DateTime': + if (/^CURRENT_TIMESTAMP(\(\d*\))?$/i.test(val) || val.toLowerCase() === 'current_timestamp()' || val.toLowerCase() === 'now()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services)); + } + // Fallback to string literal for other DateTime defaults + return (ab) => ab.StringLiteral.setValue(val); + + case 'Int': + case 'BigInt': + if (val.toLowerCase() === 'auto_increment') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services)); + } + return (ab) => ab.NumberLiteral.setValue(val); + + case 'Float': + return normalizeFloatDefault(val); + + case 'Decimal': + return normalizeDecimalDefault(val); + + case 'Boolean': + return (ab) => ab.BooleanLiteral.setValue(val.toLowerCase() === 'true' || val === '1' || val === "b'1'"); + + case 'String': + if (val.toLowerCase() === 'uuid()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('uuid', services)); + } + return (ab) => 
ab.StringLiteral.setValue(val); + case 'Json': + return (ab) => ab.StringLiteral.setValue(val); + case 'Bytes': + return (ab) => ab.StringLiteral.setValue(val); + } + + // Handle function calls (e.g., uuid(), now()) + if (val.includes('(') && val.includes(')')) { + return (ab) => + ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) => + a.setValue((v) => v.StringLiteral.setValue(val)), + ); + } + + console.warn(`Unsupported default value type: "${defaultValue}" for field type "${fieldType}". Skipping default value.`); + return null; + }, + + getFieldAttributes({ fieldName, fieldType, datatype, length, precision, services }) { + const factories: DataFieldAttributeFactory[] = []; + + // Add @updatedAt for DateTime fields named updatedAt or updated_at + if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) { + factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services))); + } + + // Add @db.* attribute if the datatype differs from the default + const dbAttr = services.shared.workspace.IndexManager.allElements('Attribute').find( + (d) => d.name.toLowerCase() === `@db.${datatype.toLowerCase()}`, + )?.node as Attribute | undefined; + + const defaultDatabaseType = this.getDefaultDatabaseType(fieldType as BuiltinType); + + if ( + dbAttr && + defaultDatabaseType && + (defaultDatabaseType.type !== datatype || + (defaultDatabaseType.precision && + defaultDatabaseType.precision !== (length ?? precision))) + ) { + const dbAttrFactory = new DataFieldAttributeFactory().setDecl(dbAttr); + const sizeValue = length ?? 
precision; + if (sizeValue !== undefined && sizeValue !== null) { + dbAttrFactory.addArg((a) => a.NumberLiteral.setValue(sizeValue)); + } + factories.push(dbAttrFactory); + } + + return factories; + }, +}; + +function getTableIntrospectionQuery() { + // Note: We use subqueries with ORDER BY before JSON_ARRAYAGG to ensure ordering + // since MySQL < 8.0.21 doesn't support ORDER BY inside JSON_ARRAYAGG. + // MySQL doesn't support multi-schema, so we don't include schema in the result. + return ` +-- Main query: one row per table/view with columns and indexes as nested JSON arrays. +-- Uses INFORMATION_SCHEMA which is MySQL's standard metadata catalog. +SELECT + t.TABLE_NAME AS \`name\`, -- table or view name + CASE t.TABLE_TYPE -- map MySQL table type strings to our internal types + WHEN 'BASE TABLE' THEN 'table' + WHEN 'VIEW' THEN 'view' + ELSE NULL + END AS \`type\`, + CASE -- for views, retrieve the SQL definition + WHEN t.TABLE_TYPE = 'VIEW' THEN v.VIEW_DEFINITION + ELSE NULL + END AS \`definition\`, + + -- ===== COLUMNS subquery ===== + -- Wraps an ordered subquery in JSON_ARRAYAGG to produce a JSON array of column objects. + ( + SELECT JSON_ARRAYAGG(col_json) + FROM ( + SELECT JSON_OBJECT( + 'ordinal_position', c.ORDINAL_POSITION, -- column position (used for sorting) + 'name', c.COLUMN_NAME, -- column name + + -- datatype: for generated/computed columns, construct the full DDL-like type definition + -- (e.g., "int GENERATED ALWAYS AS (col1 + col2) STORED") so it can be rendered as + -- Unsupported("..."); special-case tinyint(1) as 'boolean' (MySQL's boolean convention); + -- otherwise use the DATA_TYPE (e.g., 'int', 'varchar', 'datetime'). 
+ 'datatype', CASE + WHEN c.GENERATION_EXPRESSION IS NOT NULL AND c.GENERATION_EXPRESSION != '' THEN + CONCAT( + c.COLUMN_TYPE, + ' GENERATED ALWAYS AS (', + c.GENERATION_EXPRESSION, + ') ', + CASE + WHEN c.EXTRA LIKE '%STORED GENERATED%' THEN 'STORED' + ELSE 'VIRTUAL' + END + ) + WHEN c.DATA_TYPE = 'tinyint' AND c.COLUMN_TYPE = 'tinyint(1)' THEN 'boolean' + ELSE c.DATA_TYPE + END, + + -- datatype_name: for enum columns, generate a synthetic name "TableName_ColumnName" + -- (MySQL doesn't have named enum types like PostgreSQL) + 'datatype_name', CASE + WHEN c.DATA_TYPE = 'enum' THEN CONCAT(t.TABLE_NAME, '_', c.COLUMN_NAME) + ELSE NULL + END, + + 'datatype_schema', '', -- MySQL doesn't support multi-schema + 'length', c.CHARACTER_MAXIMUM_LENGTH, -- max length for string types (e.g., VARCHAR(255) -> 255) + 'precision', COALESCE(c.NUMERIC_PRECISION, c.DATETIME_PRECISION), -- numeric or datetime precision + + 'nullable', c.IS_NULLABLE = 'YES', -- true if column allows NULL + + -- default: for auto_increment columns, report 'auto_increment' instead of NULL; + -- otherwise use the COLUMN_DEFAULT value + 'default', CASE + WHEN c.EXTRA LIKE '%auto_increment%' THEN 'auto_increment' + ELSE c.COLUMN_DEFAULT + END, + + 'pk', c.COLUMN_KEY = 'PRI', -- true if column is part of the primary key + + -- unique: true if the column has a single-column unique index. + -- COLUMN_KEY = 'UNI' covers most cases, but may not be set when the column + -- also participates in other indexes (showing 'MUL' instead on some MySQL versions). + -- Also check INFORMATION_SCHEMA.STATISTICS for single-column unique indexes + -- (NON_UNIQUE = 0) to match the PostgreSQL introspection behavior. 
+ 'unique', ( + c.COLUMN_KEY = 'UNI' + OR EXISTS ( + SELECT 1 + FROM INFORMATION_SCHEMA.STATISTICS s_uni + WHERE s_uni.TABLE_SCHEMA = c.TABLE_SCHEMA + AND s_uni.TABLE_NAME = c.TABLE_NAME + AND s_uni.COLUMN_NAME = c.COLUMN_NAME + AND s_uni.NON_UNIQUE = 0 + AND s_uni.INDEX_NAME != 'PRIMARY' + AND ( + SELECT COUNT(*) + FROM INFORMATION_SCHEMA.STATISTICS s_cnt + WHERE s_cnt.TABLE_SCHEMA = s_uni.TABLE_SCHEMA + AND s_cnt.TABLE_NAME = s_uni.TABLE_NAME + AND s_cnt.INDEX_NAME = s_uni.INDEX_NAME + ) = 1 + ) + ), + 'unique_name', ( + SELECT COALESCE( + CASE WHEN c.COLUMN_KEY = 'UNI' THEN c.COLUMN_NAME ELSE NULL END, + ( + SELECT s_uni.INDEX_NAME + FROM INFORMATION_SCHEMA.STATISTICS s_uni + WHERE s_uni.TABLE_SCHEMA = c.TABLE_SCHEMA + AND s_uni.TABLE_NAME = c.TABLE_NAME + AND s_uni.COLUMN_NAME = c.COLUMN_NAME + AND s_uni.NON_UNIQUE = 0 + AND s_uni.INDEX_NAME != 'PRIMARY' + AND ( + SELECT COUNT(*) + FROM INFORMATION_SCHEMA.STATISTICS s_cnt + WHERE s_cnt.TABLE_SCHEMA = s_uni.TABLE_SCHEMA + AND s_cnt.TABLE_NAME = s_uni.TABLE_NAME + AND s_cnt.INDEX_NAME = s_uni.INDEX_NAME + ) = 1 + LIMIT 1 + ) + ) + ), + + -- computed: true if column has a generation expression (virtual or stored) + 'computed', c.GENERATION_EXPRESSION IS NOT NULL AND c.GENERATION_EXPRESSION != '', + + -- options: for enum columns, the full COLUMN_TYPE string (e.g., "enum('a','b','c')") + -- which gets parsed into individual values later + 'options', CASE + WHEN c.DATA_TYPE = 'enum' THEN c.COLUMN_TYPE + ELSE NULL + END, + + -- Foreign key info (NULL if column is not part of a FK) + 'foreign_key_schema', NULL, -- MySQL doesn't support cross-schema FKs here + 'foreign_key_table', kcu_fk.REFERENCED_TABLE_NAME, -- referenced table + 'foreign_key_column', kcu_fk.REFERENCED_COLUMN_NAME, -- referenced column + 'foreign_key_name', kcu_fk.CONSTRAINT_NAME, -- FK constraint name + 'foreign_key_on_update', rc.UPDATE_RULE, -- referential action on update (CASCADE, SET NULL, etc.) 
+ 'foreign_key_on_delete', rc.DELETE_RULE -- referential action on delete + ) AS col_json + + FROM INFORMATION_SCHEMA.COLUMNS c -- one row per column in the database + + -- Join KEY_COLUMN_USAGE to find foreign key references for this column. + -- Filter to only FK entries (REFERENCED_TABLE_NAME IS NOT NULL). + LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu_fk + ON c.TABLE_SCHEMA = kcu_fk.TABLE_SCHEMA + AND c.TABLE_NAME = kcu_fk.TABLE_NAME + AND c.COLUMN_NAME = kcu_fk.COLUMN_NAME + AND kcu_fk.REFERENCED_TABLE_NAME IS NOT NULL + + -- Join REFERENTIAL_CONSTRAINTS to get ON UPDATE / ON DELETE rules for the FK. + LEFT JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc + ON kcu_fk.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA + AND kcu_fk.CONSTRAINT_NAME = rc.CONSTRAINT_NAME + + WHERE c.TABLE_SCHEMA = t.TABLE_SCHEMA + AND c.TABLE_NAME = t.TABLE_NAME + ORDER BY c.ORDINAL_POSITION -- preserve original column order + ) AS cols_ordered + ) AS \`columns\`, + + -- ===== INDEXES subquery ===== + -- Aggregates all indexes for this table into a JSON array. 
+ ( + SELECT JSON_ARRAYAGG(idx_json) + FROM ( + SELECT JSON_OBJECT( + 'name', s.INDEX_NAME, -- index name (e.g., 'PRIMARY', 'idx_email') + 'method', s.INDEX_TYPE, -- index type (e.g., 'BTREE', 'HASH', 'FULLTEXT') + 'unique', s.NON_UNIQUE = 0, -- NON_UNIQUE=0 means it IS unique + 'primary', s.INDEX_NAME = 'PRIMARY', -- MySQL names the PK index 'PRIMARY' + 'valid', TRUE, -- MySQL doesn't expose index validity status + 'ready', TRUE, -- MySQL doesn't expose index readiness status + 'partial', FALSE, -- MySQL doesn't support partial indexes + 'predicate', NULL, -- no WHERE clause on indexes in MySQL + + -- Index columns: nested subquery for columns in this index + 'columns', ( + SELECT JSON_ARRAYAGG(idx_col_json) + FROM ( + SELECT JSON_OBJECT( + 'name', s2.COLUMN_NAME, -- column name in the index + 'expression', NULL, -- MySQL doesn't expose expression indexes via STATISTICS + -- COLLATION: 'A' = ascending, 'D' = descending, NULL = not sorted + 'order', CASE s2.COLLATION WHEN 'A' THEN 'ASC' WHEN 'D' THEN 'DESC' ELSE NULL END, + 'nulls', NULL -- MySQL doesn't expose NULLS FIRST/LAST + ) AS idx_col_json + FROM INFORMATION_SCHEMA.STATISTICS s2 -- one row per column per index + WHERE s2.TABLE_SCHEMA = s.TABLE_SCHEMA + AND s2.TABLE_NAME = s.TABLE_NAME + AND s2.INDEX_NAME = s.INDEX_NAME + ORDER BY s2.SEQ_IN_INDEX -- preserve column order within the index + ) AS idx_cols_ordered + ) + ) AS idx_json + FROM ( + -- Deduplicate: STATISTICS has one row per (index, column), but we need one row per index. + -- DISTINCT on INDEX_NAME gives us one entry per index with its metadata. 
+ SELECT DISTINCT INDEX_NAME, INDEX_TYPE, NON_UNIQUE, TABLE_SCHEMA, TABLE_NAME + FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_SCHEMA = t.TABLE_SCHEMA AND TABLE_NAME = t.TABLE_NAME + ) s + ) AS idxs_ordered + ) AS \`indexes\` + +-- === Main FROM: INFORMATION_SCHEMA.TABLES lists all tables and views === +FROM INFORMATION_SCHEMA.TABLES t +-- Join VIEWS to get VIEW_DEFINITION for view tables +LEFT JOIN INFORMATION_SCHEMA.VIEWS v + ON t.TABLE_SCHEMA = v.TABLE_SCHEMA AND t.TABLE_NAME = v.TABLE_NAME +WHERE t.TABLE_SCHEMA = ? -- only the target database + AND t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') -- exclude system tables like SYSTEM VIEW + AND t.TABLE_NAME <> '_prisma_migrations' -- exclude Prisma migration tracking table +ORDER BY t.TABLE_NAME; +`; +} + +function getEnumIntrospectionQuery() { + // MySQL doesn't have standalone enum types like PostgreSQL's CREATE TYPE. + // Instead, enum values are embedded in column definitions (e.g., COLUMN_TYPE = "enum('a','b','c')"). + // This query finds all enum columns so we can extract their allowed values. + return ` +SELECT + c.TABLE_NAME AS table_name, -- table containing the enum column + c.COLUMN_NAME AS column_name, -- column name + c.COLUMN_TYPE AS column_type -- full type string including values (e.g., "enum('val1','val2')") +FROM INFORMATION_SCHEMA.COLUMNS c +WHERE c.TABLE_SCHEMA = ? -- only the target database + AND c.DATA_TYPE = 'enum' -- only enum columns +ORDER BY c.TABLE_NAME, c.COLUMN_NAME; +`; +} + +/** + * Parse enum values from MySQL COLUMN_TYPE string like "enum('val1','val2','val3')" + */ +function parseEnumValues(columnType: string): string[] { + // Match the content inside enum(...) 
+ const match = columnType.match(/^enum\((.+)\)$/i); + if (!match || !match[1]) return []; + + const valuesString = match[1]; + const values: string[] = []; + + // Parse quoted values, handling escaped quotes + let current = ''; + let inQuote = false; + let i = 0; + + while (i < valuesString.length) { + const char = valuesString[i]; + + if (char === "'" && !inQuote) { + inQuote = true; + i++; + continue; + } + + if (char === "'" && inQuote) { + // Check for escaped quote ('') + if (valuesString[i + 1] === "'") { + current += "'"; + i += 2; + continue; + } + // End of value + values.push(current); + current = ''; + inQuote = false; + i++; + // Skip comma and any whitespace + while (i < valuesString.length && (valuesString[i] === ',' || valuesString[i] === ' ')) { + i++; + } + continue; + } + + if (inQuote) { + current += char; + } + i++; + } + + return values; +} diff --git a/packages/cli/src/actions/pull/provider/postgresql.ts b/packages/cli/src/actions/pull/provider/postgresql.ts new file mode 100644 index 000000000..6bfc9d231 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/postgresql.ts @@ -0,0 +1,663 @@ +import type { Attribute, BuiltinType, Enum, Expression } from '@zenstackhq/language/ast'; +import { AstFactory, DataFieldAttributeFactory, ExpressionBuilder } from '@zenstackhq/language/factory'; +import { Client } from 'pg'; +import { getAttributeRef, getDbName, getFunctionRef, normalizeDecimalDefault, normalizeFloatDefault } from '../utils'; +import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider'; +import type { ZModelServices } from '@zenstackhq/language'; +import { CliError } from '../../../cli-error'; + +/** + * Maps PostgreSQL internal type names to their standard SQL names for comparison. + * This is used to normalize type names when checking against default database types. 
+ */ +const pgTypnameToStandard: Record = { + int2: 'smallint', + int4: 'integer', + int8: 'bigint', + float4: 'real', + float8: 'double precision', + bool: 'boolean', + bpchar: 'character', + numeric: 'decimal', +}; + +/** + * Standard bit widths for integer/float types that shouldn't be added as precision arguments. + * PostgreSQL returns these as precision values, but they're implicit for the type. + */ +const standardTypePrecisions: Record = { + int2: 16, + smallint: 16, + int4: 32, + integer: 32, + int8: 64, + bigint: 64, + float4: 24, + real: 24, + float8: 53, + 'double precision': 53, +}; + +/** + * Maps PostgreSQL typnames (from pg_type.typname) to ZenStack native type attribute names. + * PostgreSQL introspection returns internal type names like 'int2', 'int4', 'float8', 'bpchar', + * but ZenStack attributes are named @db.SmallInt, @db.Integer, @db.DoublePrecision, @db.Char, etc. + */ +const pgTypnameToZenStackNativeType: Record = { + // integers + int2: 'SmallInt', + smallint: 'SmallInt', + int4: 'Integer', + integer: 'Integer', + int8: 'BigInt', + bigint: 'BigInt', + + // decimals and floats + numeric: 'Decimal', + decimal: 'Decimal', + float4: 'Real', + real: 'Real', + float8: 'DoublePrecision', + 'double precision': 'DoublePrecision', + + // boolean + bool: 'Boolean', + boolean: 'Boolean', + + // strings + text: 'Text', + varchar: 'VarChar', + 'character varying': 'VarChar', + bpchar: 'Char', + character: 'Char', + + // uuid + uuid: 'Uuid', + + // dates/times + date: 'Date', + time: 'Time', + timetz: 'Timetz', + timestamp: 'Timestamp', + timestamptz: 'Timestamptz', + + // binary + bytea: 'ByteA', + + // json + json: 'Json', + jsonb: 'JsonB', + + // xml + xml: 'Xml', + + // network types + inet: 'Inet', + + // bit strings + bit: 'Bit', + varbit: 'VarBit', + + // oid + oid: 'Oid', + + // money + money: 'Money', + + // citext extension + citext: 'Citext', +}; + +export const postgresql: IntrospectionProvider = { + isSupportedFeature(feature) { + const 
supportedFeatures = ['Schema', 'NativeEnum']; + return supportedFeatures.includes(feature); + }, + getBuiltinType(type) { + const t = (type || '').toLowerCase(); + + const isArray = t.startsWith('_'); + + switch (t.replace(/^_/, '')) { + // integers + case 'int2': + case 'smallint': + case 'int4': + case 'integer': + return { type: 'Int', isArray }; + case 'int8': + case 'bigint': + return { type: 'BigInt', isArray }; + + // decimals and floats + case 'numeric': + case 'decimal': + return { type: 'Decimal', isArray }; + case 'float4': + case 'real': + case 'float8': + case 'double precision': + return { type: 'Float', isArray }; + + // boolean + case 'bool': + case 'boolean': + return { type: 'Boolean', isArray }; + + // strings + case 'text': + case 'varchar': + case 'bpchar': + case 'character varying': + case 'character': + return { type: 'String', isArray }; + + // uuid + case 'uuid': + return { type: 'String', isArray }; + + // dates/times + case 'date': + case 'time': + case 'timetz': + case 'timestamp': + case 'timestamptz': + return { type: 'DateTime', isArray }; + + // binary + case 'bytea': + return { type: 'Bytes', isArray }; + + // json + case 'json': + case 'jsonb': + return { type: 'Json', isArray }; + default: + return { type: 'Unsupported' as const, isArray }; + } + }, + async introspect(connectionString: string, options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' }): Promise { + const client = new Client({ connectionString }); + await client.connect(); + + try { + const { rows: tables } = await client.query(tableIntrospectionQuery); + const { rows: enums } = await client.query(enumIntrospectionQuery); + + // Filter tables and enums to only include those from the selected schemas + const filteredTables = tables.filter((t) => options.schemas.includes(t.schema)); + const filteredEnums = enums.filter((e) => options.schemas.includes(e.schema_name)); + + return { + enums: filteredEnums, + tables: filteredTables, + }; + } 
finally { + await client.end(); + } + }, + getDefaultDatabaseType(type: BuiltinType) { + switch (type) { + case 'String': + return { type: 'text' }; + case 'Boolean': + return { type: 'boolean' }; + case 'Int': + return { type: 'integer' }; + case 'BigInt': + return { type: 'bigint' }; + case 'Float': + return { type: 'double precision' }; + case 'Decimal': + return { type: 'decimal' }; + case 'DateTime': + return { type: 'timestamp', precision: 3 }; + case 'Json': + return { type: 'jsonb' }; + case 'Bytes': + return { type: 'bytea' }; + } + }, + getDefaultValue({ defaultValue, fieldType, datatype, datatype_name, services, enums }) { + const val = defaultValue.trim(); + + // Handle enum defaults (PostgreSQL returns 'value'::enum_type) + if (datatype === 'enum' && datatype_name) { + const enumDef = enums.find((e) => getDbName(e) === datatype_name); + if (enumDef) { + // Extract the enum value from the default (format: 'VALUE'::"enum_type") + const enumValue = val.replace(/'/g, '').split('::')[0]?.trim(); + const enumField = enumDef.fields.find((f) => getDbName(f) === enumValue); + if (enumField) { + return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + } + // Fall through to typeCastingConvert if datatype_name lookup fails + return typeCastingConvert({defaultValue,enums,val,services}); + } + + switch (fieldType) { + case 'DateTime': + if (val === 'CURRENT_TIMESTAMP' || val === 'now()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services)); + } + + if (val.includes('::')) { + return typeCastingConvert({defaultValue,enums,val,services}); + } + + // Fallback to string literal for other DateTime defaults + return (ab) => ab.StringLiteral.setValue(val); + + case 'Int': + case 'BigInt': + if (val.startsWith('nextval(')) { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services)); + } + + if (val.includes('::')) { + return typeCastingConvert({defaultValue,enums,val,services}); + } + return (ab) => 
ab.NumberLiteral.setValue(val);

            case 'Float':
                if (val.includes('::')) {
                    return typeCastingConvert({defaultValue,enums,val,services});
                }
                return normalizeFloatDefault(val);

            case 'Decimal':
                if (val.includes('::')) {
                    return typeCastingConvert({defaultValue,enums,val,services});
                }
                return normalizeDecimalDefault(val);

            case 'Boolean':
                return (ab) => ab.BooleanLiteral.setValue(val === 'true');

            case 'String':
                if (val.includes('::')) {
                    return typeCastingConvert({defaultValue,enums,val,services});
                }

                // strip the surrounding quotes and un-double embedded quotes ('' -> ')
                if (val.startsWith("'") && val.endsWith("'")) {
                    return (ab) => ab.StringLiteral.setValue(val.slice(1, -1).replace(/''/g, "'"));
                }
                return (ab) => ab.StringLiteral.setValue(val);
            case 'Json':
                if (val.includes('::')) {
                    return typeCastingConvert({defaultValue,enums,val,services});
                }
                return (ab) => ab.StringLiteral.setValue(val);
            case 'Bytes':
                if (val.includes('::')) {
                    return typeCastingConvert({defaultValue,enums,val,services});
                }
                return (ab) => ab.StringLiteral.setValue(val);
        }

        // Reached only when fieldType is 'Unsupported' (all builtin cases return
        // above). A parenthesized expression is kept as dbgenerated("...").
        if (val.includes('(') && val.includes(')')) {
            return (ab) =>
                ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) =>
                    a.setValue((v) => v.StringLiteral.setValue(val)),
                );
        }

        console.warn(`Unsupported default value type: "${defaultValue}" for field type "${fieldType}".
Skipping default value.`);
        return null;
    },

    /**
     * Compute extra field attributes (beyond @default): @updatedAt for
     * conventionally-named DateTime columns, and a @db.* native-type attribute
     * when the column's physical type differs from the builtin type's default.
     */
    getFieldAttributes({ fieldName, fieldType, datatype, length, precision, services }) {
        const factories: DataFieldAttributeFactory[] = [];

        // Add @updatedAt for DateTime fields named updatedAt or updated_at
        if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) {
            factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services)));
        }

        // Map PostgreSQL typname to ZenStack native type attribute name
        // PostgreSQL returns typnames like 'int2', 'float8', 'bpchar', but ZenStack attributes
        // are named @db.SmallInt, @db.DoublePrecision, @db.Char, etc.
        const nativeTypeName = pgTypnameToZenStackNativeType[datatype.toLowerCase()] ?? datatype;

        // Add @db.* attribute if the datatype differs from the default.
        // Attribute declarations are looked up case-insensitively in the language index.
        const dbAttr = services.shared.workspace.IndexManager.allElements('Attribute').find(
            (d) => d.name.toLowerCase() === `@db.${nativeTypeName.toLowerCase()}`,
        )?.node as Attribute | undefined;

        const defaultDatabaseType = this.getDefaultDatabaseType(fieldType as BuiltinType);

        // Normalize datatype for comparison (e.g., 'int4' -> 'integer')
        const normalizedDatatype = pgTypnameToStandard[datatype.toLowerCase()] ?? datatype.toLowerCase();

        // Check if the precision is the standard bit width for this type (shouldn't be added)
        const standardPrecision = standardTypePrecisions[datatype.toLowerCase()];
        const isStandardPrecision = standardPrecision !== undefined && precision === standardPrecision;

        // Emit @db.* when the physical type name differs from the default, or when
        // it matches but carries a non-default precision (length is preferred
        // over precision for the comparison, mirroring the argument below).
        if (
            dbAttr &&
            defaultDatabaseType &&
            (defaultDatabaseType.type !== normalizedDatatype ||
                (defaultDatabaseType.precision &&
                    defaultDatabaseType.precision !== (length ?? precision)))
        ) {
            const dbAttrFactory = new DataFieldAttributeFactory().setDecl(dbAttr);
            // Only add length/precision if it's meaningful (not the standard bit width for the type)
            if ((length || precision) && !isStandardPrecision) {
                dbAttrFactory.addArg((a) => a.NumberLiteral.setValue(length! || precision!));
            }
            factories.push(dbAttrFactory);
        }

        return factories;
    },
};

// One row per enum type in the database, with its labels aggregated in
// declaration order. Schema filtering happens later in introspect().
const enumIntrospectionQuery = `
SELECT
    n.nspname AS schema_name, -- schema the enum belongs to (e.g., 'public')
    t.typname AS enum_type, -- enum type name as defined in CREATE TYPE
    coalesce(json_agg(e.enumlabel ORDER BY e.enumsortorder), '[]') AS values -- ordered list of enum labels as JSON array
FROM pg_type t -- pg_type: catalog of all data types
JOIN pg_enum e ON t.oid = e.enumtypid -- pg_enum: one row per enum label; join to get labels for this enum type
JOIN pg_namespace n ON n.oid = t.typnamespace -- pg_namespace: schema info; join to get the schema name
GROUP BY schema_name, enum_type -- one row per enum type, with all labels aggregated
ORDER BY schema_name, enum_type;`;

const tableIntrospectionQuery = `
-- Main query: one row per table/view with columns and indexes as nested JSON arrays.
-- Joins pg_class (tables/views) with pg_namespace (schemas).
SELECT
    "ns"."nspname" AS "schema", -- schema name (e.g., 'public')
    "cls"."relname" AS "name", -- table or view name
    CASE "cls"."relkind" -- relkind: 'r' = ordinary table, 'v' = view
        WHEN 'r' THEN 'table'
        WHEN 'v' THEN 'view'
        ELSE NULL
    END AS "type",
    CASE -- for views, retrieve the SQL definition
        WHEN "cls"."relkind" = 'v' THEN pg_get_viewdef("cls"."oid", true)
        ELSE NULL
    END AS "definition",

    -- ===== COLUMNS subquery =====
    -- Aggregates all columns for this table into a JSON array.
    (
        SELECT coalesce(json_agg(agg), '[]')
        FROM (
            SELECT
                "att"."attname" AS "name", -- column name

                -- datatype: if the type is an enum, report 'enum';
                -- if the column is generated/computed, construct the full DDL-like type definition
                -- (e.g., "text GENERATED ALWAYS AS (expr) STORED") so it can be rendered as Unsupported("...");
                -- otherwise use the pg_type name.
                CASE
                    WHEN EXISTS (
                        SELECT 1 FROM "pg_catalog"."pg_enum" AS "e"
                        WHERE "e"."enumtypid" = "typ"."oid"
                    ) THEN 'enum'
                    WHEN "att"."attgenerated" != '' THEN
                        format_type("att"."atttypid", "att"."atttypmod")
                        || ' GENERATED ALWAYS AS ('
                        || pg_get_expr("def"."adbin", "def"."adrelid")
                        || ') '
                        || CASE "att"."attgenerated"
                            WHEN 's' THEN 'STORED'
                            WHEN 'v' THEN 'VIRTUAL'
                            ELSE 'STORED'
                        END
                    ELSE "typ"."typname"::text -- internal type name (e.g., 'int4', 'varchar', 'text'); cast to text to prevent CASE from coercing result to name type (max 63 chars)
                END AS "datatype",

                -- datatype_name: for enums only, the actual enum type name (used to look up the enum definition)
                CASE
                    WHEN EXISTS (
                        SELECT 1 FROM "pg_catalog"."pg_enum" AS "e"
                        WHERE "e"."enumtypid" = "typ"."oid"
                    ) THEN "typ"."typname"
                    ELSE NULL
                END AS "datatype_name",

                "tns"."nspname" AS "datatype_schema", -- schema where the data type is defined
                "c"."character_maximum_length" AS "length", -- max length for char/varchar types (from information_schema)
                COALESCE("c"."numeric_precision", "c"."datetime_precision") AS "precision", -- numeric or datetime precision

                -- Foreign key info (NULL if column is not part of a FK constraint)
                "fk_ns"."nspname" AS "foreign_key_schema", -- schema of the referenced table
                "fk_cls"."relname" AS "foreign_key_table", -- referenced table name
                "fk_att"."attname" AS "foreign_key_column", -- referenced column name
                "fk_con"."conname" AS "foreign_key_name", -- FK constraint name

                -- FK referential actions: decode single-char codes to human-readable strings
                CASE "fk_con"."confupdtype"
                    WHEN 'a' THEN 'NO ACTION'
                    WHEN 'r' THEN 'RESTRICT'
                    WHEN 'c' THEN 'CASCADE'
                    WHEN 'n' THEN 'SET NULL'
                    WHEN 'd' THEN 'SET DEFAULT'
                    ELSE NULL
                END AS "foreign_key_on_update",
                CASE "fk_con"."confdeltype"
                    WHEN 'a' THEN 'NO ACTION'
                    WHEN 'r' THEN 'RESTRICT'
                    WHEN 'c' THEN 'CASCADE'
                    WHEN 'n' THEN 'SET NULL'
                    WHEN 'd' THEN 'SET DEFAULT'
                    ELSE NULL
                END AS "foreign_key_on_delete",

                -- pk: true if this column is part of the table's primary key constraint
                "pk_con"."conkey" IS NOT NULL AS "pk",

                -- unique: true if the column has a single-column UNIQUE constraint OR a single-column unique index
                (
                    -- Check for a single-column UNIQUE constraint (contype = 'u')
                    EXISTS (
                        SELECT 1
                        FROM "pg_catalog"."pg_constraint" AS "u_con"
                        WHERE "u_con"."contype" = 'u' -- 'u' = unique constraint
                            AND "u_con"."conrelid" = "cls"."oid" -- on this table
                            AND array_length("u_con"."conkey", 1) = 1 -- single-column only
                            AND "att"."attnum" = ANY ("u_con"."conkey") -- this column is in the constraint
                    )
                    OR
                    -- Check for a single-column unique index (may exist without an explicit constraint)
                    EXISTS (
                        SELECT 1
                        FROM "pg_catalog"."pg_index" AS "u_idx"
                        WHERE "u_idx"."indrelid" = "cls"."oid" -- on this table
                            AND "u_idx"."indisunique" = TRUE -- it's a unique index
                            AND "u_idx"."indnkeyatts" = 1 -- single key column
                            AND "att"."attnum" = ANY ("u_idx"."indkey"::int2[]) -- this column is the key
                    )
                ) AS "unique",

                -- unique_name: the name of the unique constraint or index (whichever exists first)
                (
                    SELECT COALESCE(
                        -- Try constraint name first
                        (
                            SELECT "u_con"."conname"
                            FROM "pg_catalog"."pg_constraint" AS "u_con"
                            WHERE "u_con"."contype" = 'u'
                                AND "u_con"."conrelid" = "cls"."oid"
                                AND array_length("u_con"."conkey", 1) = 1
                                AND "att"."attnum" = ANY ("u_con"."conkey")
                            LIMIT 1
                        ),
                        -- Fall back to unique index name
                        (
                            SELECT "u_idx_cls"."relname"
                            FROM "pg_catalog"."pg_index" AS "u_idx"
                            JOIN "pg_catalog"."pg_class" AS "u_idx_cls" ON "u_idx"."indexrelid" = "u_idx_cls"."oid"
                            WHERE "u_idx"."indrelid" = "cls"."oid"
                                AND "u_idx"."indisunique" = TRUE
                                AND "u_idx"."indnkeyatts" = 1
                                AND "att"."attnum" = ANY ("u_idx"."indkey"::int2[])
                            LIMIT 1
                        )
                    )
                ) AS "unique_name",

                "att"."attgenerated" != '' AS "computed", -- true if column is a generated/computed column
                -- For generated columns, pg_attrdef stores the generation expression (not a default),
                -- so we must null it out to avoid emitting a spurious @default(dbgenerated(...)) attribute.
                CASE
                    WHEN "att"."attgenerated" != '' THEN NULL
                    ELSE pg_get_expr("def"."adbin", "def"."adrelid")
                END AS "default", -- column default expression as text (e.g., 'nextval(...)', '0', 'now()')
                "att"."attnotnull" != TRUE AS "nullable", -- true if column allows NULL values

                -- options: for enum columns, aggregates all allowed enum labels into a JSON array
                coalesce(
                    (
                        SELECT json_agg("enm"."enumlabel") AS "o"
                        FROM "pg_catalog"."pg_enum" AS "enm"
                        WHERE "enm"."enumtypid" = "typ"."oid"
                    ),
                    '[]'
                ) AS "options"

            -- === FROM / JOINs for the columns subquery ===

            -- pg_attribute: one row per table column (attnum >= 0 excludes system columns)
            FROM "pg_catalog"."pg_attribute" AS "att"

            -- pg_type: data type of the column (e.g., int4, text, custom_enum)
            INNER JOIN "pg_catalog"."pg_type" AS "typ" ON "typ"."oid" = "att"."atttypid"

            -- pg_namespace for the type: needed to determine which schema the type lives in
            INNER JOIN "pg_catalog"."pg_namespace" AS "tns" ON "tns"."oid" = "typ"."typnamespace"

            -- information_schema.columns: provides length/precision info not easily available from pg_catalog
            LEFT JOIN "information_schema"."columns" AS "c" ON "c"."table_schema" = "ns"."nspname"
                AND "c"."table_name" = "cls"."relname"
                AND "c"."column_name" = "att"."attname"

            -- pg_constraint (primary key): join on contype='p' to detect if column is part of PK
            LEFT JOIN "pg_catalog"."pg_constraint" AS "pk_con" ON "pk_con"."contype" = 'p'
                AND "pk_con"."conrelid" = "cls"."oid"
                AND "att"."attnum" = ANY ("pk_con"."conkey")

            -- pg_constraint (foreign key): join on contype='f' to get FK details for this column
            LEFT JOIN "pg_catalog"."pg_constraint" AS "fk_con" ON "fk_con"."contype" = 'f'
                AND "fk_con"."conrelid" = "cls"."oid"
                AND "att"."attnum" = ANY ("fk_con"."conkey")

            -- pg_class for FK target table: resolve the referenced table's OID to its name
            LEFT JOIN "pg_catalog"."pg_class" AS "fk_cls" ON "fk_cls"."oid" = "fk_con"."confrelid"

            -- pg_namespace for FK target: get the schema of the referenced table
            LEFT JOIN "pg_catalog"."pg_namespace" AS "fk_ns" ON "fk_ns"."oid" = "fk_cls"."relnamespace"

            -- pg_attribute for FK target column: resolve the referenced column number to its name.
            -- Use array_position to correlate by position: find this source column's index in conkey,
            -- then pick the referenced attnum at that same index from confkey.
            -- This ensures composite FKs correctly map each source column to its corresponding target column.
            LEFT JOIN "pg_catalog"."pg_attribute" AS "fk_att" ON "fk_att"."attrelid" = "fk_cls"."oid"
                AND "fk_att"."attnum" = "fk_con"."confkey"[array_position("fk_con"."conkey", "att"."attnum")]

            -- pg_attrdef: column defaults; adbin contains the internal expression, decoded via pg_get_expr()
            LEFT JOIN "pg_catalog"."pg_attrdef" AS "def" ON "def"."adrelid" = "cls"."oid" AND "def"."adnum" = "att"."attnum"

            WHERE
                "att"."attrelid" = "cls"."oid" -- only columns belonging to this table
                AND "att"."attnum" >= 0 -- exclude system columns (ctid, xmin, etc. have attnum < 0)
                AND "att"."attisdropped" != TRUE -- exclude dropped (deleted) columns
            ORDER BY "att"."attnum" -- preserve original column order
        ) AS agg
    ) AS "columns",

    -- ===== INDEXES subquery =====
    -- Aggregates all indexes for this table into a JSON array.
    (
        SELECT coalesce(json_agg(agg), '[]')
        FROM (
            SELECT
                "idx_cls"."relname" AS "name", -- index name
                "am"."amname" AS "method", -- access method (e.g., 'btree', 'hash', 'gin', 'gist')
                "idx"."indisunique" AS "unique", -- true if unique index
                "idx"."indisprimary" AS "primary", -- true if this is the PK index
                "idx"."indisvalid" AS "valid", -- false during concurrent index builds
                "idx"."indisready" AS "ready", -- true when index is ready for inserts
                ("idx"."indpred" IS NOT NULL) AS "partial", -- true if index has a WHERE clause (partial index)
                pg_get_expr("idx"."indpred", "idx"."indrelid") AS "predicate", -- the WHERE clause expression for partial indexes

                -- Index columns: iterate over each position in the index key array
                (
                    SELECT json_agg(
                        json_build_object(
                            -- 'name': column name, or for expression indexes the expression text
                            'name', COALESCE("att"."attname", pg_get_indexdef("idx"."indexrelid", "s"."i", true)),
                            -- 'expression': non-null only for expression-based index columns (e.g., lower(name))
                            'expression', CASE WHEN "att"."attname" IS NULL THEN pg_get_indexdef("idx"."indexrelid", "s"."i", true) ELSE NULL END,
                            -- 'order': sort direction; bit 0 of indoption = 1 means DESC
                            'order', CASE ((( "idx"."indoption"::int2[] )["s"."i"] & 1)) WHEN 1 THEN 'DESC' ELSE 'ASC' END,
                            -- 'nulls': null ordering; bit 1 of indoption = 1 means NULLS FIRST
                            'nulls', CASE (((( "idx"."indoption"::int2[] )["s"."i"] >> 1) & 1)) WHEN 1 THEN 'NULLS FIRST' ELSE 'NULLS LAST' END
                        )
                        ORDER BY "s"."i" -- preserve column order within the index
                    )
                    -- generate_subscripts creates one row per index key position (1-based)
                    FROM generate_subscripts("idx"."indkey"::int2[], 1) AS "s"("i")
                    -- Join to pg_attribute to resolve column numbers to names
                    -- NULL attname means it's an expression index column
                    LEFT JOIN "pg_catalog"."pg_attribute" AS "att"
                        ON "att"."attrelid" = "cls"."oid"
                        AND "att"."attnum" = ("idx"."indkey"::int2[])["s"."i"]
                ) AS "columns"

            FROM "pg_catalog"."pg_index" AS "idx" -- pg_index: one row per index
            JOIN "pg_catalog"."pg_class" AS "idx_cls" ON "idx"."indexrelid" = "idx_cls"."oid" -- index's own pg_class entry (for the name)
            JOIN "pg_catalog"."pg_am" AS "am" ON "idx_cls"."relam" = "am"."oid" -- access method catalog
            WHERE "idx"."indrelid" = "cls"."oid" -- only indexes on this table
            ORDER BY "idx_cls"."relname"
        ) AS agg
    ) AS "indexes"

-- === Main FROM: pg_class (tables and views) joined with pg_namespace (schemas) ===
FROM "pg_catalog"."pg_class" AS "cls"
INNER JOIN "pg_catalog"."pg_namespace" AS "ns" ON "cls"."relnamespace" = "ns"."oid"
WHERE
    "ns"."nspname" !~ '^pg_' -- exclude PostgreSQL internal schemas (pg_catalog, pg_toast, etc.)
    AND "ns"."nspname" != 'information_schema' -- exclude the information_schema
    AND "cls"."relkind" IN ('r', 'v') -- only tables ('r') and views ('v')
    AND "cls"."relname" !~ '^pg_' -- exclude system tables starting with pg_
    AND "cls"."relname" !~ '_prisma_migrations' -- exclude Prisma migration tracking table
    -- NOTE(review): the pattern above is unanchored, so any table whose name merely
    -- CONTAINS '_prisma_migrations' is also excluded — confirm this is intended.
    ORDER BY "ns"."nspname", "cls"."relname" ASC;
`;

// Convert a PostgreSQL cast-style default ('value'::type) into a @default
// expression. Returns null when no @default should be emitted; unknown casts
// fall back to dbgenerated("...").
// NOTE(review): the leading replace strips ALL single quotes and split('::')
// takes the first segment, so values that themselves contain quotes or '::'
// would be mangled — presumably rare; verify against real defaults.
function typeCastingConvert({defaultValue, enums, val, services}:{val: string, enums: Enum[], defaultValue:string, services:ZModelServices}): ((builder: ExpressionBuilder) => AstFactory) | null {
    const [value, type] = val
        .replace(/'/g, '')
        .split('::')
        .map((s) => s.trim()) as [string, string];
    switch (type) {
        case 'character varying':
        case 'uuid':
        case 'json':
        case 'jsonb':
        case 'text':
            // NULL::text means "no default"
            if (value === 'NULL') return null;
            return (ab) => ab.StringLiteral.setValue(value);
        case 'real':
            return (ab) => ab.NumberLiteral.setValue(value);
        default: {
            // Possibly an enum-typed cast; look the type up among declared enums
            const enumDef = enums.find((e) => getDbName(e, true) === type);
            if (!enumDef) {
                // Unknown cast target — preserve the raw expression verbatim
                return (ab) =>
                    ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) =>
                        a.setValue((v) => v.StringLiteral.setValue(val)),
                    );
            }
            const enumField = enumDef.fields.find((v) =>
getDbName(v) === value); + if (!enumField) { + throw new CliError( + `Enum value ${value} not found in enum ${type} for default value ${defaultValue}`, + ); + } + return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + } +} diff --git a/packages/cli/src/actions/pull/provider/provider.ts b/packages/cli/src/actions/pull/provider/provider.ts new file mode 100644 index 000000000..7b3127132 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/provider.ts @@ -0,0 +1,96 @@ +import type { ZModelServices } from '@zenstackhq/language'; +import type { BuiltinType, Enum, Expression } from '@zenstackhq/language/ast'; +import type { AstFactory, DataFieldAttributeFactory, ExpressionBuilder } from '@zenstackhq/language/factory'; + +export type Cascade = 'NO ACTION' | 'RESTRICT' | 'CASCADE' | 'SET NULL' | 'SET DEFAULT' | null; + +export interface IntrospectedTable { + schema: string; + name: string; + type: 'table' | 'view'; + definition: string | null; + columns: { + name: string; + datatype: string; + datatype_name: string | null; + length: number | null; + precision: number | null; + datatype_schema: string; + foreign_key_schema: string | null; + foreign_key_table: string | null; + foreign_key_column: string | null; + foreign_key_name: string | null; + foreign_key_on_update: Cascade; + foreign_key_on_delete: Cascade; + pk: boolean; + computed: boolean; + nullable: boolean; + unique: boolean; + unique_name: string | null; + default: string | null; + }[]; + indexes: { + name: string; + method: string | null; + unique: boolean; + primary: boolean; + valid: boolean; + ready: boolean; + partial: boolean; + predicate: string | null; + columns: { + name: string; + expression: string | null; + order: 'ASC' | 'DESC' | null; + nulls: string | null; + }[]; + }[]; +} + +export type IntrospectedEnum = { + schema_name: string; + enum_type: string; + values: string[]; +}; + +export type IntrospectedSchema = { + tables: IntrospectedTable[]; + enums: IntrospectedEnum[]; +}; + +export 
type DatabaseFeature = 'Schema' | 'NativeEnum'; + +export interface IntrospectionProvider { + introspect(connectionString: string, options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' }): Promise; + getBuiltinType(type: string): { + type: BuiltinType | 'Unsupported'; + isArray: boolean; + }; + getDefaultDatabaseType(type: BuiltinType): { precision?: number; type: string } | undefined; + /** + * Get the expression builder callback for a field's @default attribute value. + * Returns null if no @default attribute should be added. + * The callback will be passed to DataFieldAttributeFactory.addArg(). + */ + getDefaultValue(args: { + fieldType: BuiltinType | 'Unsupported'; + datatype: string; + datatype_name: string | null; + defaultValue: string; + services: ZModelServices; + enums: Enum[]; + }): ((builder: ExpressionBuilder) => AstFactory) | null; + /** + * Get additional field attributes based on field type and name (e.g., @updatedAt for DateTime fields, @db.* attributes). + * This is separate from getDefaultValue to keep concerns separated. 
+ */ + getFieldAttributes(args: { + fieldName: string; + fieldType: BuiltinType | 'Unsupported'; + datatype: string; + length: number | null; + precision: number | null; + services: ZModelServices; + }): DataFieldAttributeFactory[]; + isSupportedFeature(feature: DatabaseFeature): boolean; +} diff --git a/packages/cli/src/actions/pull/provider/sqlite.ts b/packages/cli/src/actions/pull/provider/sqlite.ts new file mode 100644 index 000000000..c4b06f367 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/sqlite.ts @@ -0,0 +1,481 @@ +import { DataFieldAttributeFactory } from '@zenstackhq/language/factory'; +import { getAttributeRef, getDbName, getFunctionRef, normalizeDecimalDefault, normalizeFloatDefault } from '../utils'; +import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider'; + +// Note: We dynamically import better-sqlite3 inside the async function to avoid +// requiring it at module load time for environments that don't use SQLite. + +export const sqlite: IntrospectionProvider = { + isSupportedFeature(feature) { + switch (feature) { + case 'Schema': + // Multi-schema feature is not available for SQLite because it doesn't have + // the same concept of schemas as namespaces (unlike PostgreSQL, CockroachDB, SQL Server). + return false; + case 'NativeEnum': + // SQLite doesn't support native enum types + return false; + default: + return false; + } + }, + getBuiltinType(type) { + // Strip parenthesized constraints (e.g., VARCHAR(255) → varchar, DECIMAL(10,2) → decimal) + const t = (type || '').toLowerCase().trim().replace(/\(.*\)$/, '').trim(); + // SQLite has no array types + const isArray = false; + + // SQLite type affinity rules (https://www.sqlite.org/datatype3.html): + // 1. If type contains "INT" → INTEGER affinity + // 2. If type contains "CHAR", "CLOB", or "TEXT" → TEXT affinity + // 3. If type contains "BLOB" or no type → BLOB affinity + // 4. 
If type contains "REAL", "FLOA", or "DOUB" → REAL affinity + // 5. Otherwise → NUMERIC affinity + + // Handle specific known types first for better mapping + switch (t) { + // INTEGER types (SQLite: INT, INTEGER, TINYINT, SMALLINT, MEDIUMINT, INT2, INT8) + case 'integer': + case 'int': + case 'tinyint': + case 'smallint': + case 'mediumint': + case 'int2': + case 'int8': + return { type: 'Int', isArray }; + + // BIGINT - map to BigInt for large integers + case 'bigint': + case 'unsigned big int': + return { type: 'BigInt', isArray }; + + // TEXT types (SQLite: CHARACTER, VARCHAR, VARYING CHARACTER, NCHAR, NATIVE CHARACTER, NVARCHAR, TEXT, CLOB) + case 'text': + case 'varchar': + case 'char': + case 'character': + case 'varying character': + case 'nchar': + case 'native character': + case 'nvarchar': + case 'clob': + return { type: 'String', isArray }; + + // BLOB type + case 'blob': + return { type: 'Bytes', isArray }; + + // REAL types (SQLite: REAL, DOUBLE, DOUBLE PRECISION, FLOAT) + case 'real': + case 'float': + case 'double': + case 'double precision': + return { type: 'Float', isArray }; + + // NUMERIC types (SQLite: NUMERIC, DECIMAL) + case 'numeric': + case 'decimal': + return { type: 'Decimal', isArray }; + + // DateTime types + case 'datetime': + case 'date': + case 'time': + case 'timestamp': + return { type: 'DateTime', isArray }; + + // JSON types + case 'json': + case 'jsonb': + return { type: 'Json', isArray }; + + // Boolean types + case 'boolean': + case 'bool': + return { type: 'Boolean', isArray }; + + default: { + // SQLite affinity rule #3: columns with no declared type have BLOB affinity + if (!t) { + return { type: 'Bytes', isArray }; + } + // Fallback: Use SQLite affinity rules for unknown types + if (t.includes('int')) { + return { type: 'Int', isArray }; + } + if (t.includes('char') || t.includes('clob') || t.includes('text')) { + return { type: 'String', isArray }; + } + if (t.includes('blob')) { + return { type: 'Bytes', isArray }; + } + 
if (t.includes('real') || t.includes('floa') || t.includes('doub')) { + return { type: 'Float', isArray }; + } + // Default to Unsupported for truly unknown types + return { type: 'Unsupported' as const, isArray }; + } + } + }, + + getDefaultDatabaseType() { + return undefined; + }, + + async introspect(connectionString: string, _options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' }): Promise { + const SQLite = (await import('better-sqlite3')).default; + const db = new SQLite(connectionString, { readonly: true }); + + try { + const all = (sql: string): T[] => { + const stmt: any = db.prepare(sql); + return stmt.all() as T[]; + }; + + // List user tables and views from sqlite_schema (the master catalog). + // sqlite_schema contains one row per table, view, index, and trigger. + // We filter to only tables/views and exclude internal sqlite_* objects. + // The 'sql' column contains the original CREATE TABLE/VIEW statement. + const tablesRaw = all<{ name: string; type: 'table' | 'view'; definition: string | null }>( + "SELECT name, type, sql AS definition FROM sqlite_schema WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' ORDER BY name", + ); + + // Detect AUTOINCREMENT by parsing the CREATE TABLE statement + // The sqlite_sequence table only has entries after rows are inserted, + // so we need to check the actual table definition instead + const autoIncrementTables = new Set(); + for (const t of tablesRaw) { + if (t.type === 'table' && t.definition) { + // AUTOINCREMENT keyword appears in PRIMARY KEY definition + // e.g., PRIMARY KEY("id" AUTOINCREMENT) or PRIMARY KEY(id AUTOINCREMENT) + if (/\bAUTOINCREMENT\b/i.test(t.definition)) { + autoIncrementTables.add(t.name); + } + } + } + + const tables: IntrospectedTable[] = []; + + for (const t of tablesRaw) { + const tableName = t.name; + const schema = ''; + + // Check if this table has autoincrement (detected by parsing the CREATE TABLE DDL) + const hasAutoIncrement = 
autoIncrementTables.has(tableName);

                // PRAGMA table_xinfo: extended version of table_info that also includes hidden/generated columns.
                // Returns one row per column with: cid (column index), name, type, notnull, dflt_value, pk.
                // hidden: 0 = normal, 1 = hidden (virtual table), 2 = generated VIRTUAL, 3 = generated STORED.
                const columnsInfo = all<{
                    cid: number;
                    name: string;
                    type: string;
                    notnull: number;
                    dflt_value: string | null;
                    pk: number;
                    hidden?: number;
                }>(`PRAGMA table_xinfo('${tableName.replace(/'/g, "''")}')`);

                // PRAGMA index_list: returns all indexes on a table.
                // Each row has: seq (index sequence), name, unique (1 if unique), origin ('c'=CREATE INDEX,
                // 'u'=UNIQUE constraint, 'pk'=PRIMARY KEY), partial (1 if partial index).
                // We exclude sqlite_autoindex_* entries which are auto-generated for UNIQUE constraints.
                // NOTE(review): because sqlite_autoindex_* rows are dropped here, single-column
                // UNIQUE constraints declared inline in the DDL may not be reflected in
                // uniqueSingleColumn below — confirm this is the intended behavior.
                const tableNameEsc = tableName.replace(/'/g, "''");
                const idxList = all<{
                    seq: number;
                    name: string;
                    unique: number;
                    origin: string;
                    partial: number;
                }>(`PRAGMA index_list('${tableNameEsc}')`).filter((r) => !r.name.startsWith('sqlite_autoindex_'));

                // Detect single-column unique constraints by inspecting each unique index.
                // PRAGMA index_info: returns the columns that make up an index.
                // If a unique (non-partial) index has exactly one column, that column is "unique".
                const uniqueSingleColumn = new Set();
                const uniqueIndexRows = idxList.filter((r) => r.unique === 1 && r.partial !== 1);
                for (const idx of uniqueIndexRows) {
                    const idxCols = all<{ name: string }>(`PRAGMA index_info('${idx.name.replace(/'/g, "''")}')`);
                    if (idxCols.length === 1 && idxCols[0]?.name) {
                        uniqueSingleColumn.add(idxCols[0].name);
                    }
                }

                // Build detailed index info for each index.
                // PRAGMA index_info returns one row per column in the index.
                // SQLite doesn't expose access method, predicate, or sort order through PRAGMAs.
                const indexes: IntrospectedTable['indexes'] = idxList.map((idx) => {
                    const idxCols = all<{ name: string }>(`PRAGMA index_info('${idx.name.replace(/'/g, "''")}')`);
                    return {
                        name: idx.name,
                        method: null, // SQLite does not expose index method
                        unique: idx.unique === 1,
                        primary: false, // SQLite does not expose this directly; handled via pk in columns
                        valid: true, // SQLite does not expose index validity
                        ready: true, // SQLite does not expose index readiness
                        partial: idx.partial === 1,
                        predicate: idx.partial === 1 ? '[partial]' : null, // SQLite does not expose index predicate
                        columns: idxCols.map((col) => ({
                            name: col.name,
                            expression: null,
                            order: null,
                            nulls: null,
                        })),
                    };
                });

                // PRAGMA foreign_key_list: returns all foreign key constraints on a table.
                // Each row represents one column in a FK constraint with: id (FK id, shared by multi-column FKs),
                // seq (column index within the FK), table (referenced table), from (local column),
                // to (referenced column), on_update, on_delete (referential actions).
                const fkRows = all<{
                    id: number;
                    seq: number;
                    table: string;
                    from: string;
                    to: string | null;
                    on_update: any;
                    on_delete: any;
                }>(`PRAGMA foreign_key_list('${tableName.replace(/'/g, "''")}')`);

                // Extract FK constraint names from CREATE TABLE statement.
                // Captures the constraint name and the full parenthesized column list from
                // FOREIGN KEY(...), then splits and parses individual column names so that
                // composite FKs (e.g., FOREIGN KEY("col1", "col2")) are handled correctly.
                const fkConstraintNames = new Map();
                if (t.definition) {
                    // Match: CONSTRAINT "name" FOREIGN KEY()
                    // Group 1/2: quoted/unquoted constraint name
                    // Group 3: the full content inside FOREIGN KEY(...)
                    const fkRegex = /CONSTRAINT\s+(?:["'`]([^"'`]+)["'`]|(\w+))\s+FOREIGN\s+KEY\s*\(([^)]+)\)/gi;
                    let match;
                    while ((match = fkRegex.exec(t.definition)) !== null) {
                        const constraintName = match[1] || match[2];
                        const columnList = match[3];
                        if (constraintName && columnList) {
                            // Split the column list on commas and strip quotes/whitespace
                            // to extract each individual column name.
                            const columns = columnList.split(',').map((col) => col.trim().replace(/^["'`]|["'`]$/g, ''));
                            for (const col of columns) {
                                if (col) {
                                    fkConstraintNames.set(col, constraintName);
                                }
                            }
                        }
                    }
                }

                // Per-column FK lookup assembled from the PRAGMA rows above.
                const fkByColumn = new Map<
                    string,
                    {
                        foreign_key_schema: string | null;
                        foreign_key_table: string | null;
                        foreign_key_column: string | null;
                        foreign_key_name: string | null;
                        foreign_key_on_update: IntrospectedTable['columns'][number]['foreign_key_on_update'];
                        foreign_key_on_delete: IntrospectedTable['columns'][number]['foreign_key_on_delete'];
                    }
                >();

                for (const fk of fkRows) {
                    fkByColumn.set(fk.from, {
                        foreign_key_schema: '',
                        foreign_key_table: fk.table || null,
                        foreign_key_column: fk.to || null,
                        foreign_key_name: fkConstraintNames.get(fk.from) ?? null,
                        foreign_key_on_update: (fk.on_update as any) ?? null,
                        foreign_key_on_delete: (fk.on_delete as any) ?? null,
                    });
                }

                // Pre-extract full column type definitions from DDL for generated columns.
                // PRAGMA table_xinfo only returns the base type (e.g., "TEXT"), but for
                // generated columns we need the full definition including the expression
                // (e.g., "TEXT GENERATED ALWAYS AS (...) STORED") so they are pulled as
                // Unsupported("...") — matching Prisma's introspection behavior.
                const generatedColDefs = t.definition ? extractColumnTypeDefs(t.definition) : new Map();

                const columns: IntrospectedTable['columns'] = [];
                for (const c of columnsInfo) {
                    // hidden: 0 = normal, 1 = hidden (virtual table) → skip,
                    // 2 = generated VIRTUAL, 3 = generated STORED → mark computed
                    const hidden = c.hidden ?? 0;
                    if (hidden === 1) continue;

                    const isGenerated = hidden === 2 || hidden === 3;

                    const fk = fkByColumn.get(c.name);

                    // Determine default value - check for autoincrement
                    // AUTOINCREMENT in SQLite can only be on INTEGER PRIMARY KEY column
                    let defaultValue = c.dflt_value;
                    if (hasAutoIncrement && c.pk) {
                        defaultValue = 'autoincrement';
                    }

                    // For generated columns, use the full DDL type definition so that
                    // getBuiltinType returns Unsupported and the column is rendered as
                    // Unsupported("TYPE GENERATED ALWAYS AS (...) STORED/VIRTUAL").
                    let datatype = c.type || '';
                    if (isGenerated) {
                        const fullDef = generatedColDefs.get(c.name);
                        if (fullDef) {
                            datatype = fullDef;
                        }
                    }

                    columns.push({
                        name: c.name,
                        datatype,
                        datatype_name: null, // SQLite doesn't support native enums
                        length: null,
                        precision: null,
                        datatype_schema: schema,
                        foreign_key_schema: fk?.foreign_key_schema ?? null,
                        foreign_key_table: fk?.foreign_key_table ?? null,
                        foreign_key_column: fk?.foreign_key_column ?? null,
                        foreign_key_name: fk?.foreign_key_name ?? null,
                        foreign_key_on_update: fk?.foreign_key_on_update ?? null,
                        foreign_key_on_delete: fk?.foreign_key_on_delete ?? null,
                        pk: !!c.pk,
                        computed: isGenerated,
                        nullable: c.notnull !== 1,
                        default: defaultValue,
                        unique: uniqueSingleColumn.has(c.name),
                        unique_name: null,
                    });
                }

                tables.push({ schema, name: tableName, columns, type: t.type, definition: t.definition, indexes });
            }

            const enums: IntrospectedEnum[] = []; // SQLite doesn't support enums

            return { tables, enums };
        } finally {
            // always close the database handle, even when a PRAGMA fails
            db.close();
        }
    },

    // Build the @default expression for a SQLite column, or null for none.
    getDefaultValue({ defaultValue, fieldType, services, enums }) { // datatype and datatype_name not used for SQLite
        const val = defaultValue.trim();

        switch (fieldType) {
            case 'DateTime':
                if (val === 'CURRENT_TIMESTAMP' || val === 'now()') {
                    return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services));
                }
                // Fallback to string literal for other DateTime defaults
                return (ab) => ab.StringLiteral.setValue(val);

            case 'Int':
            case 'BigInt':
                // 'autoincrement' is the sentinel set by introspect() above
                if (val === 'autoincrement') {
                    return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services));
                }
                return (ab) => ab.NumberLiteral.setValue(val);

            case 'Float':
                return normalizeFloatDefault(val);

            case 'Decimal':
                return normalizeDecimalDefault(val);

            case 'Boolean':
                // SQLite stores booleans as 0/1; accept both spellings
                return (ab) => ab.BooleanLiteral.setValue(val === 'true' || val === '1');
            case 'String':
                if (val.startsWith("'") && val.endsWith("'")) {
                    const strippedName = val.slice(1, -1);
                    // If the quoted value matches a declared enum member, emit a
                    // reference to it instead of a plain string literal.
                    const enumDef = enums.find((e) => e.fields.find((v) => getDbName(v) === strippedName));
                    if (enumDef) {
                        const enumField = enumDef.fields.find((v) => getDbName(v) === strippedName);
                        if (enumField) return (ab) => ab.ReferenceExpr.setTarget(enumField);
                    }
                    return (ab) => ab.StringLiteral.setValue(strippedName);
                }
                return (ab) => ab.StringLiteral.setValue(val);
            case 'Json':
                return (ab) => ab.StringLiteral.setValue(val);
            case 'Bytes':
                return (ab) => ab.StringLiteral.setValue(val);
        }

        console.warn(`Unsupported default value type: "${defaultValue}" for field type "${fieldType}".
Skipping default value.`); + return null; + }, + + getFieldAttributes({ fieldName, fieldType, services }) { + const factories: DataFieldAttributeFactory[] = []; + + // Add @updatedAt for DateTime fields named updatedAt or updated_at + if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) { + factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services))); + } + + return factories; + }, +}; + +/** + * Extract column type definitions from a CREATE TABLE DDL statement. + * Returns a map of column name → full type definition string (everything after the column name). + * Used to get the complete type including GENERATED ALWAYS AS (...) STORED/VIRTUAL for generated columns. + */ +function extractColumnTypeDefs(ddl: string): Map { + // Find the content inside CREATE TABLE "name" ( ... ) + // Use a paren-depth approach to find the matching closing paren. + const openIdx = ddl.indexOf('('); + if (openIdx === -1) return new Map(); + + let depth = 1; + let closeIdx = -1; + for (let i = openIdx + 1; i < ddl.length; i++) { + if (ddl[i] === '(') depth++; + else if (ddl[i] === ')') { + depth--; + if (depth === 0) { + closeIdx = i; + break; + } + } + } + if (closeIdx === -1) return new Map(); + + const content = ddl.substring(openIdx + 1, closeIdx); + + // Split column definitions on commas, respecting nested parentheses. + const defs: string[] = []; + let current = ''; + depth = 0; + for (const char of content) { + if (char === '(') depth++; + else if (char === ')') depth--; + else if (char === ',' && depth === 0) { + defs.push(current.trim()); + current = ''; + continue; + } + current += char; + } + if (current.trim()) defs.push(current.trim()); + + // Map column name → type definition (everything after the column name). + // Table constraints (CONSTRAINT, PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK) + // are skipped since they don't define columns. 
+    const result = new Map<string, string>();
getDatasource(model: Model) { + const datasource = model.declarations.find((d) => d.$type === 'DataSource'); + if (!datasource) { + throw new CliError('No datasource declaration found in the schema.'); + } + + const urlField = datasource.fields.find((f) => f.name === 'url'); + + if (!urlField) throw new CliError(`No url field found in the datasource declaration.`); + + let url = getStringLiteral(urlField.value); + + if (!url && isInvocationExpr(urlField.value)) { + const envName = getStringLiteral(urlField.value.args[0]?.value); + if (!envName) { + throw new CliError('The url field must be a string literal or an env().'); + } + if (!process.env[envName]) { + throw new CliError( + `Environment variable ${envName} is not set, please set it to the database connection string.`, + ); + } + url = process.env[envName]; + } + + if (!url) { + throw new CliError('The url field must be a string literal or an env().'); + } + + if (url.startsWith('file:')) { + url = new URL(url, `file:${model.$document!.uri.path}`).pathname; + if (process.platform === 'win32' && url[0] === '/') url = url.slice(1); + } + + const defaultSchemaField = datasource.fields.find((f) => f.name === 'defaultSchema'); + const defaultSchema = (defaultSchemaField && getStringLiteral(defaultSchemaField.value)) || 'public'; + + const schemasField = datasource.fields.find((f) => f.name === 'schemas'); + const schemas = + (schemasField && + getLiteralArray(schemasField.value) + ?.filter((s) => s !== undefined)) as string[] || + []; + + const provider = getStringLiteral( + datasource.fields.find((f) => f.name === 'provider')?.value, + ); + if (!provider) { + throw new CliError(`Datasource "${datasource.name}" is missing a "provider" field.`); + } + + return { + name: datasource.name, + provider: provider as DataSourceProviderType, + url, + defaultSchema, + schemas, + allSchemas: [defaultSchema, ...schemas], + }; +} + +export function getDbName(decl: AbstractDeclaration | DataField | EnumField, includeSchema: 
boolean = false): string { + if (!('attributes' in decl)) return decl.name; + + const schemaAttr = decl.attributes.find((a) => a.decl.ref?.name === '@@schema'); + let schema = 'public'; + if (schemaAttr) { + const schemaAttrValue = schemaAttr.args[0]?.value; + if (schemaAttrValue?.$type === 'StringLiteral') { + schema = schemaAttrValue.value; + } + } + + const formatName = (name: string) => `${schema && includeSchema ? `${schema}.` : ''}${name}`; + + const nameAttr = decl.attributes.find((a) => a.decl.ref?.name === '@@map' || a.decl.ref?.name === '@map'); + if (!nameAttr) return formatName(decl.name); + const attrValue = nameAttr.args[0]?.value; + + if (attrValue?.$type !== 'StringLiteral') return formatName(decl.name); + + return formatName(attrValue.value); +} + +export function getRelationFkName(decl: DataField): string | undefined { + const relationAttr = decl?.attributes.find((a) => a.decl.ref?.name === '@relation'); + const schemaAttrValue = relationAttr?.args.find((a) => a.name === 'map')?.value as StringLiteral; + return schemaAttrValue?.value; +} + +/** + * Gets the FK field names from the @relation attribute's `fields` argument. + * Returns a sorted, comma-separated string of field names for comparison. 
+ * e.g., @relation(fields: [userId], references: [id]) -> "userId" + * e.g., @relation(fields: [postId, tagId], references: [id, id]) -> "postId,tagId" + */ +export function getRelationFieldsKey(decl: DataField): string | undefined { + const relationAttr = decl?.attributes.find((a) => a.decl.ref?.name === '@relation'); + if (!relationAttr) return undefined; + + const fieldsArg = relationAttr.args.find((a) => a.name === 'fields')?.value; + if (!fieldsArg || fieldsArg.$type !== 'ArrayExpr') return undefined; + + const fieldNames = fieldsArg.items + .filter((item): item is ReferenceExpr => item.$type === 'ReferenceExpr') + .map((item) => item.target?.$refText || item.target?.ref?.name) + .filter((name): name is string => !!name) + .sort(); + + return fieldNames.length > 0 ? fieldNames.join(',') : undefined; +} + +export function getDbSchemaName(decl: DataModel | Enum): string { + const schemaAttr = decl.attributes.find((a) => a.decl.ref?.name === '@@schema'); + if (!schemaAttr) return 'public'; + const attrValue = schemaAttr.args[0]?.value; + + if (attrValue?.$type !== 'StringLiteral') return 'public'; + + return attrValue.value; +} + +export function getDeclarationRef( + type: T['$type'], + name: string, + services: ZModelServices, +) { + const node = services.shared.workspace.IndexManager.allElements(type).find( + (m) => m.node && getDbName(m.node as T) === name, + )?.node; + if (!node) throw new CliError(`Declaration not found: ${name}`); + return node as T; +} + +export function getEnumRef(name: string, services: ZModelServices) { + return getDeclarationRef('Enum', name, services); +} + +export function getModelRef(name: string, services: ZModelServices) { + return getDeclarationRef('DataModel', name, services); +} + +export function getAttributeRef(name: string, services: ZModelServices) { + return getDeclarationRef('Attribute', name, services); +} + +export function getFunctionRef(name: string, services: ZModelServices) { + return 
getDeclarationRef<FunctionDecl>('FunctionDecl', name, services);
If a directory path is provided, files are written to the directory and imports are kept.', + ), + ) + .addOption( + new Option('--model-casing ', 'set the casing of generated models').default( + 'pascal', + ), + ) + .addOption( + new Option('--field-casing ', 'set the casing of generated fields').default( + 'camel', + ), + ) + .addOption( + new Option('--always-map', 'always add @map and @@map attributes to models and fields').default(false), + ) + .addOption( + new Option('--quote ', 'set the quote style of generated schema files').default('single'), + ) + .addOption(new Option('--indent ', 'set the indentation of the generated schema files').default(4)) + .action((options) => dbAction('pull', options)); + dbCommand .command('seed') .description('Seed the database') diff --git a/packages/cli/test/casing.test.ts b/packages/cli/test/casing.test.ts new file mode 100644 index 000000000..60b9d1c88 --- /dev/null +++ b/packages/cli/test/casing.test.ts @@ -0,0 +1,130 @@ +import { describe, expect, it } from 'vitest'; +import { resolveNameCasing, toPascalCase, toCamelCase, toSnakeCase } from '../src/actions/pull/casing'; + +describe('toPascalCase', () => { + it('converts snake_case', () => { + expect(toPascalCase('user_status')).toBe('UserStatus'); + expect(toPascalCase('first_name')).toBe('FirstName'); + }); + + it('converts kebab-case', () => { + expect(toPascalCase('user-status')).toBe('UserStatus'); + }); + + it('capitalizes first char of lowercase', () => { + expect(toPascalCase('user')).toBe('User'); + }); + + it('preserves already PascalCase', () => { + expect(toPascalCase('UserStatus')).toBe('UserStatus'); + }); + + it('preserves all-uppercase strings', () => { + expect(toPascalCase('ACTIVE')).toBe('ACTIVE'); + expect(toPascalCase('USER')).toBe('USER'); + expect(toPascalCase('MODERATOR')).toBe('MODERATOR'); + expect(toPascalCase('SET_NULL')).toBe('SET_NULL'); + expect(toPascalCase('NO_ACTION')).toBe('NO_ACTION'); + }); + + it('converts mixed snake_case with 
uppercase', () => { + expect(toPascalCase('User_status')).toBe('UserStatus'); + }); +}); + +describe('toCamelCase', () => { + it('converts snake_case', () => { + expect(toCamelCase('user_status')).toBe('userStatus'); + expect(toCamelCase('first_name')).toBe('firstName'); + }); + + it('converts kebab-case', () => { + expect(toCamelCase('user-status')).toBe('userStatus'); + }); + + it('lowercases first char of PascalCase', () => { + expect(toCamelCase('User')).toBe('user'); + expect(toCamelCase('Post')).toBe('post'); + }); + + it('preserves already camelCase', () => { + expect(toCamelCase('userStatus')).toBe('userStatus'); + }); + + it('preserves all-uppercase strings', () => { + expect(toCamelCase('ACTIVE')).toBe('ACTIVE'); + expect(toCamelCase('INACTIVE')).toBe('INACTIVE'); + expect(toCamelCase('SUSPENDED')).toBe('SUSPENDED'); + expect(toCamelCase('USER')).toBe('USER'); + expect(toCamelCase('SET_NULL')).toBe('SET_NULL'); + expect(toCamelCase('NO_ACTION')).toBe('NO_ACTION'); + }); +}); + +describe('toSnakeCase', () => { + it('converts camelCase', () => { + expect(toSnakeCase('userStatus')).toBe('user_status'); + expect(toSnakeCase('firstName')).toBe('first_name'); + }); + + it('converts PascalCase', () => { + expect(toSnakeCase('UserStatus')).toBe('user_status'); + }); + + it('converts kebab-case', () => { + expect(toSnakeCase('user-status')).toBe('user_status'); + }); + + it('preserves already snake_case', () => { + expect(toSnakeCase('user_status')).toBe('user_status'); + }); + + it('preserves all-uppercase strings', () => { + expect(toSnakeCase('ACTIVE')).toBe('ACTIVE'); + expect(toSnakeCase('INACTIVE')).toBe('INACTIVE'); + expect(toSnakeCase('SUSPENDED')).toBe('SUSPENDED'); + expect(toSnakeCase('SET_NULL')).toBe('SET_NULL'); + expect(toSnakeCase('NO_ACTION')).toBe('NO_ACTION'); + }); +}); + +describe('resolveNameCasing', () => { + it('applies pascal casing', () => { + expect(resolveNameCasing('pascal', 'user_status')).toEqual({ modified: true, name: 'UserStatus' 
}); + expect(resolveNameCasing('pascal', 'User')).toEqual({ modified: false, name: 'User' }); + }); + + it('applies camel casing', () => { + expect(resolveNameCasing('camel', 'User')).toEqual({ modified: true, name: 'user' }); + expect(resolveNameCasing('camel', 'first_name')).toEqual({ modified: true, name: 'firstName' }); + }); + + it('applies snake casing', () => { + expect(resolveNameCasing('snake', 'UserStatus')).toEqual({ modified: true, name: 'user_status' }); + expect(resolveNameCasing('snake', 'user_status')).toEqual({ modified: false, name: 'user_status' }); + }); + + it('preserves name with none casing', () => { + expect(resolveNameCasing('none', 'User_status')).toEqual({ modified: false, name: 'User_status' }); + expect(resolveNameCasing('none', 'ACTIVE')).toEqual({ modified: false, name: 'ACTIVE' }); + }); + + it('preserves all-uppercase enum values across all casings', () => { + expect(resolveNameCasing('pascal', 'ACTIVE')).toEqual({ modified: false, name: 'ACTIVE' }); + expect(resolveNameCasing('camel', 'ACTIVE')).toEqual({ modified: false, name: 'ACTIVE' }); + expect(resolveNameCasing('snake', 'ACTIVE')).toEqual({ modified: false, name: 'ACTIVE' }); + expect(resolveNameCasing('none', 'ACTIVE')).toEqual({ modified: false, name: 'ACTIVE' }); + }); + + it('preserves all-uppercase enum values with underscores across all casings', () => { + expect(resolveNameCasing('pascal', 'SET_NULL')).toEqual({ modified: false, name: 'SET_NULL' }); + expect(resolveNameCasing('camel', 'SET_NULL')).toEqual({ modified: false, name: 'SET_NULL' }); + expect(resolveNameCasing('snake', 'SET_NULL')).toEqual({ modified: false, name: 'SET_NULL' }); + expect(resolveNameCasing('none', 'SET_NULL')).toEqual({ modified: false, name: 'SET_NULL' }); + }); + + it('prefixes names starting with a digit', () => { + expect(resolveNameCasing('none', '1foo')).toEqual({ modified: true, name: '_1foo' }); + expect(resolveNameCasing('camel', '1foo')).toEqual({ modified: true, name: '_1foo' }); + 
}); +}); diff --git a/packages/cli/test/check.test.ts b/packages/cli/test/check.test.ts index 287bb6b80..99d31ecda 100644 --- a/packages/cli/test/check.test.ts +++ b/packages/cli/test/check.test.ts @@ -36,37 +36,37 @@ model Post { `; describe('CLI validate command test', () => { - it('should validate a valid schema successfully', () => { - const workDir = createProject(validModel); + it('should validate a valid schema successfully', async () => { + const { workDir } = await createProject(validModel); // Should not throw an error expect(() => runCli('check', workDir)).not.toThrow(); }); - it('should fail validation for invalid schema', () => { - const workDir = createProject(invalidModel); + it('should fail validation for invalid schema', async () => { + const { workDir } = await createProject(invalidModel); // Should throw an error due to validation failure expect(() => runCli('check', workDir)).toThrow(); }); - it('should respect custom schema location', () => { - const workDir = createProject(validModel); + it('should respect custom schema location', async () => { + const { workDir } = await createProject(validModel); fs.renameSync(path.join(workDir, 'zenstack/schema.zmodel'), path.join(workDir, 'zenstack/custom.zmodel')); // Should not throw an error when using custom schema path expect(() => runCli('check --schema ./zenstack/custom.zmodel', workDir)).not.toThrow(); }); - it('should fail when schema file does not exist', () => { - const workDir = createProject(validModel); + it('should fail when schema file does not exist', async () => { + const { workDir } = await createProject(validModel); // Should throw an error when schema file doesn't exist expect(() => runCli('check --schema ./nonexistent.zmodel', workDir)).toThrow(); }); - it('should respect package.json config', () => { - const workDir = createProject(validModel); + it('should respect package.json config', async () => { + const { workDir } = await createProject(validModel); 
fs.mkdirSync(path.join(workDir, 'foo')); fs.renameSync(path.join(workDir, 'zenstack/schema.zmodel'), path.join(workDir, 'foo/schema.zmodel')); fs.rmdirSync(path.join(workDir, 'zenstack')); @@ -81,19 +81,14 @@ describe('CLI validate command test', () => { expect(() => runCli('check', workDir)).not.toThrow(); }); - it('should validate schema with syntax errors', () => { + it('should validate schema with syntax errors', async () => { const modelWithSyntaxError = ` -datasource db { - provider = "sqlite" - url = "file:./dev.db" -} - model User { id String @id @default(cuid()) email String @unique // Missing closing brace - syntax error `; - const workDir = createProject(modelWithSyntaxError, false); + const { workDir } = await createProject(modelWithSyntaxError); // Should throw an error due to syntax error expect(() => runCli('check', workDir)).toThrow(); diff --git a/packages/cli/test/db.test.ts b/packages/cli/test/db.test.ts index 636dcff8f..b5b76d4fa 100644 --- a/packages/cli/test/db.test.ts +++ b/packages/cli/test/db.test.ts @@ -10,14 +10,14 @@ model User { `; describe('CLI db commands test', () => { - it('should generate a database with db push', () => { - const workDir = createProject(model); + it('should generate a database with db push', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); runCli('db push', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); }); - it('should seed the database with db seed with seed script', () => { - const workDir = createProject(model); + it('should seed the database with db seed with seed script', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); const pkgJson = JSON.parse(fs.readFileSync(path.join(workDir, 'package.json'), 'utf8')); pkgJson.zenstack = { seed: 'node seed.js', @@ -35,8 +35,8 @@ fs.writeFileSync('seed.txt', 'success'); 
expect(fs.readFileSync(path.join(workDir, 'seed.txt'), 'utf8')).toBe('success'); }); - it('should seed the database after migrate reset', () => { - const workDir = createProject(model); + it('should seed the database after migrate reset', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); const pkgJson = JSON.parse(fs.readFileSync(path.join(workDir, 'package.json'), 'utf8')); pkgJson.zenstack = { seed: 'node seed.js', @@ -54,8 +54,8 @@ fs.writeFileSync('seed.txt', 'success'); expect(fs.readFileSync(path.join(workDir, 'seed.txt'), 'utf8')).toBe('success'); }); - it('should skip seeding the database without seed script', () => { - const workDir = createProject(model); + it('should skip seeding the database without seed script', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); runCli('db seed', workDir); }); }); diff --git a/packages/cli/test/db/pull.test.ts b/packages/cli/test/db/pull.test.ts new file mode 100644 index 000000000..2750a2228 --- /dev/null +++ b/packages/cli/test/db/pull.test.ts @@ -0,0 +1,1366 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { createProject, getDefaultPrelude, getTestDbName, getTestDbUrl, runCli } from '../utils'; +import { formatDocument } from '@zenstackhq/language'; +import { getTestDbProvider } from '@zenstackhq/testtools'; + +const getSchema = (workDir: string) => fs.readFileSync(path.join(workDir, 'zenstack/schema.zmodel')).toString(); + +describe('DB pull - Common features (all providers)', () => { + describe('Pull from zero - restore complete schema from database', () => { + it('should restore basic schema with all supported types', async () => { + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + name String? 
+ age Int @default(0) + balance Decimal @default(0.00) + isActive Boolean @default(true) + bigCounter BigInt @default(0) + score Float @default(0.0) + bio String? + avatar Bytes? + metadata Json? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt +}`, + ); + runCli('db push', workDir); + + // Store the schema after db push (this is what provider names will be) + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + // Remove schema content to simulate restoration from zero + fs.writeFileSync(schemaFile, getDefaultPrelude()); + + // Pull should fully restore the schema + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should restore schema with relations', async () => { + const { workDir, schema } = await createProject( + `model Post { + id Int @id @default(autoincrement()) + title String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + posts Post[] +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should restore schema with many-to-many relations', async () => { + const { workDir, schema } = await createProject( + `model Post { + id Int @id @default(autoincrement()) + title String + postTags PostTag[] +} + +model PostTag { + post Post @relation(fields: [postId], references: [id], onDelete: Cascade) + postId Int + tag Tag @relation(fields: [tagId], references: [id], onDelete: Cascade) + tagId Int + + @@id([postId, tagId]) +} + +model Tag { + id Int @id @default(autoincrement()) + name String @unique + postTags PostTag[] +}`, + ); + runCli('db push', workDir); + 
+ const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should restore self-referencing model with multiple FK columns without duplicate fields', async () => { + const { workDir, schema } = await createProject( + `model Category { + id Int @id @default(autoincrement()) + categoryParentId Category? @relation('Category_parentIdToCategory', fields: [parentId], references: [id]) + parentId Int? + categoryBuddyId Category? @relation('Category_buddyIdToCategory', fields: [buddyId], references: [id]) + buddyId Int? + categoryMentorId Category? @relation('Category_mentorIdToCategory', fields: [mentorId], references: [id]) + mentorId Int? + categoryParentIdToCategoryId Category[] @relation('Category_parentIdToCategory') + categoryBuddyIdToCategoryId Category[] @relation('Category_buddyIdToCategory') + categoryMentorIdToCategoryId Category[] @relation('Category_mentorIdToCategory') +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + expect(restoredSchema).toEqual(schema); + }); + + it('should preserve self-referencing model with multiple FK columns', async () => { + const { workDir, schema } = await createProject( + `model Category { + id Int @id @default(autoincrement()) + category Category? @relation('Category_parentIdToCategory', fields: [parentId], references: [id]) + parentId Int? + buddy Category? @relation('Category_buddyIdToCategory', fields: [buddyId], references: [id]) + buddyId Int? + mentor Category? @relation('Category_mentorIdToCategory', fields: [mentorId], references: [id]) + mentorId Int? 
+ categories Category[] @relation('Category_parentIdToCategory') + buddys Category[] @relation('Category_buddyIdToCategory') + mentees Category[] @relation('Category_mentorIdToCategory') +}`, + ); + runCli('db push', workDir); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + expect(restoredSchema).toEqual(schema); + }); + + it('should restore one-to-one relation when FK is the single-column primary key', async () => { + const { workDir, schema } = await createProject( + `model Profile { + user User @relation(fields: [id], references: [id], onDelete: Cascade) + id Int @id @default(autoincrement()) + bio String? +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + profile Profile? +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should restore schema with indexes and unique constraints', async () => { + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + username String + firstName String + lastName String + role String + + @@unique([username, email]) + @@index([role]) + @@index([firstName, lastName]) + @@index([email, username, role]) +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should restore schema with composite primary keys', async () => { + const { workDir, schema } = await createProject( + `model UserRole { + userId String + role String + grantedAt DateTime @default(now()) + + @@id([userId, role]) 
+}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should preserve Decimal and Float default value precision', async () => { + const { workDir, schema } = await createProject( + `model Product { + id Int @id @default(autoincrement()) + price Decimal @default(99.99) + discount Decimal @default(0.50) + taxRate Decimal @default(7.00) + weight Float @default(1.5) + rating Float @default(4.0) + temperature Float @default(98.6) +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + }); + + describe('Pull with existing schema - preserve schema features', () => { + it('should preserve field and table mappings', async () => { + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique @map('email_address') + firstName String @map('first_name') + lastName String @map('last_name') + + @@map('users') +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + + expect(getSchema(workDir)).toEqual(schema); + }); + + it('should not modify a comprehensive schema with all features', async () => { + const { workDir, schema } = await createProject(`model User { + id Int @id @default(autoincrement()) + email String @unique @map('email_address') + name String? @default('Anonymous') + role Role @default(USER) + profile Profile? + shared_profile Profile? @relation('shared') + posts Post[] + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + jsonData Json? 
+ balance Decimal @default(0.00) + isActive Boolean @default(true) + bigCounter BigInt @default(0) + bytes Bytes? + + @@index([role]) + @@map('users') +} + +model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId Int @unique + user_shared User @relation('shared', fields: [shared_userId], references: [id], onDelete: Cascade) + shared_userId Int @unique + bio String? + avatarUrl String? + + @@map('profiles') +} + +model Post { + id Int @id @default(autoincrement()) + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + title String + content String? + published Boolean @default(false) + tags PostTag[] + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + slug String + score Float @default(0.0) + metadata Json? + + @@unique([authorId, slug]) + @@index([authorId, published]) + @@map('posts') +} + +model Tag { + id Int @id @default(autoincrement()) + name String @unique + posts PostTag[] + createdAt DateTime @default(now()) + + @@index([name], name: 'tag_name_idx') + @@map('tags') +} + +model PostTag { + post Post @relation(fields: [postId], references: [id], onDelete: Cascade) + postId Int + tag Tag @relation(fields: [tagId], references: [id], onDelete: Cascade) + tagId Int + assignedAt DateTime @default(now()) + note String? @default('initial') + + @@id([postId, tagId]) + @@map('post_tags') +} + +enum Role { + USER + ADMIN + MODERATOR +}`, +// When using MySQL, the introspection simply overrides the enum and cannot detect if it exists with the same name because it only stores the values. +// TODO: Create a better way to handle this, possibly by finding enums by their values as well if the schema exists. 
+ ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + expect(getSchema(workDir)).toEqual(schema); + }); + + it('should preserve imports when pulling with multi-file schema', async () => { + const { workDir } = await createProject('', { customPrelude: true }); + const schemaPath = path.join(workDir, 'zenstack/schema.zmodel'); + const modelsDir = path.join(workDir, 'zenstack/models'); + + fs.mkdirSync(modelsDir, { recursive: true }); + + // Create main schema with imports + const mainSchema = await formatDocument(`import './models/user' +import './models/post' + +${getDefaultPrelude()}`); + fs.writeFileSync(schemaPath, mainSchema); + + // Create user model + const userModel = await formatDocument(`import './post' + +model User { + id Int @id @default(autoincrement()) + email String @unique + name String? + posts Post[] + createdAt DateTime @default(now()) +}`); + fs.writeFileSync(path.join(modelsDir, 'user.zmodel'), userModel); + + // Create post model + const postModel = await formatDocument(`import './user' + +model Post { + id Int @id @default(autoincrement()) + title String + content String? 
+ author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + createdAt DateTime @default(now()) +}`); + fs.writeFileSync(path.join(modelsDir, 'post.zmodel'), postModel); + + runCli('db push', workDir); + + // Pull and verify imports are preserved + runCli('db pull --indent 4', workDir); + + const pulledMainSchema = fs.readFileSync(schemaPath).toString(); + const pulledUserSchema = fs.readFileSync(path.join(modelsDir, 'user.zmodel')).toString(); + const pulledPostSchema = fs.readFileSync(path.join(modelsDir, 'post.zmodel')).toString(); + + expect(pulledMainSchema).toEqual(mainSchema); + expect(pulledUserSchema).toEqual(userModel); + expect(pulledPostSchema).toEqual(postModel); + }); + }); + + describe('Pull should preserve enum declaration order', () => { + + it('should preserve interleaved enum and model ordering', async () => { + const { workDir, schema } = await createProject( + `enum Role { + USER + ADMIN +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + role Role @default(USER) + status Status @default(ACTIVE) +} + +enum Status { + ACTIVE + INACTIVE + SUSPENDED +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + + // Enum-model-enum ordering should be preserved + expect(getSchema(workDir)).toEqual(schema); + }); + }); + + describe('Pull should consolidate shared enums', () => { + it('should consolidate per-column enums back to the original shared enum', async () => { + const { workDir, schema } = await createProject( + `enum Status { + ACTIVE + INACTIVE + SUSPENDED +} + +model User { + id Int @id @default(autoincrement()) + status Status @default(ACTIVE) +} + +model Group { + id Int @id @default(autoincrement()) + status Status @default(ACTIVE) +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + + // MySQL creates per-column enums (UserStatus, GroupStatus) but + // consolidation should map them back to the original shared 
Status enum + expect(getSchema(workDir)).toEqual(schema); + }); + + it('should consolidate per-column enums with --always-map without stale @@map', async () => { + // This test targets a bug where consolidateEnums renames keepEnum.name + // to oldEnum.name but leaves the synthetic @@map attribute added by + // syncEnums, so getDbName(keepEnum) still returns the old mapped name + // (e.g., 'UserStatus') instead of the consolidated name ('Status'), + // preventing matching in the downstream delete/add enum logic. + const { workDir } = await createProject( + `enum Status { + ACTIVE + INACTIVE + SUSPENDED +} + +model User { + id Int @id @default(autoincrement()) + status Status @default(ACTIVE) +} + +model Group { + id Int @id @default(autoincrement()) + status Status @default(ACTIVE) +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4 --always-map', workDir); + + const pulledSchema = getSchema(workDir); + + // The consolidated enum should be named Status, not UserStatus/GroupStatus + expect(pulledSchema).toContain('enum Status'); + expect(pulledSchema).not.toContain('enum UserStatus'); + expect(pulledSchema).not.toContain('enum GroupStatus'); + + // There should be no stale @@map referencing the synthetic per-column name + expect(pulledSchema).not.toMatch(/@@map\(['"]UserStatus['"]\)/); + expect(pulledSchema).not.toMatch(/@@map\(['"]GroupStatus['"]\)/); + }); + }); + + describe('Pull should preserve triple-slash comments on enums', () => { + it('should preserve triple-slash comments on enum declarations and fields', async () => { + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + status Status @default(ACTIVE) +} + +/// User account status +/// ACTIVE - user can log in +/// INACTIVE - user is disabled +enum Status { + /// User can log in + ACTIVE + /// User is disabled + INACTIVE + /// User is suspended + SUSPENDED +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', 
workDir); + + expect(getSchema(workDir)).toEqual(schema); + }); + }); + + describe('Pull should preserve data validation attributes', () => { + it('should preserve field-level validation attributes after db pull', async () => { + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique @email + name String @length(min: 2, max: 100) + website String? @url + code String? @regex('^[A-Z]+$') + age Int @gt(0) + score Float @gte(0.0) + rating Decimal @lt(10) + rank BigInt @lte(999) +}`, + ); + runCli('db push', workDir); + + // Pull should preserve all validation attributes + runCli('db pull --indent 4', workDir); + + expect(getSchema(workDir)).toEqual(schema); + }); + + it('should preserve string transformation attributes after db pull', async () => { + const { workDir, schema } = await createProject( + `model Setting { + id Int @id @default(autoincrement()) + key String @trim @lower + value String @trim @upper +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + + expect(getSchema(workDir)).toEqual(schema); + }); + + it('should preserve model-level @@validate attribute after db pull', async () => { + const { workDir, schema } = await createProject( + `model Product { + id Int @id @default(autoincrement()) + minPrice Decimal @default(0.00) + maxPrice Decimal @default(100.00) + + @@validate(minPrice < maxPrice, 'minPrice must be less than maxPrice') +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + + expect(getSchema(workDir)).toEqual(schema); + }); + }); + + describe('Pull should update existing field definitions when database changes', () => { + it('should update field type when database column type changes', async () => { + // Step 1: Create initial schema with String field + const { workDir } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + age String +}`, + ); + runCli('db 
push', workDir); + + // Step 2: Modify schema to change age from String to Int + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const updatedSchema = await formatDocument(`${getDefaultPrelude()} + +model User { + id Int @id @default(autoincrement()) + email String @unique + age Int +}`); + fs.writeFileSync(schemaFile, updatedSchema); + runCli('db push', workDir); + + // Step 3: Revert schema back to original (with String type) + const originalSchema = await formatDocument(`${getDefaultPrelude()} + +model User { + id Int @id @default(autoincrement()) + email String @unique + age String +}`); + fs.writeFileSync(schemaFile, originalSchema); + + // Step 4: Pull from database - should detect that age is now Int + runCli('db pull --indent 4', workDir); + + // Step 5: Verify that pulled schema has Int type (matching database) + const pulledSchema = getSchema(workDir); + expect(pulledSchema).toEqual(updatedSchema); + }); + + it('should update field optionality when database column nullability changes', async () => { + // Step 1: Create initial schema with required field + const { workDir } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + name String +}`, + ); + runCli('db push', workDir); + + // Step 2: Modify schema to make name optional + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const updatedSchema = await formatDocument(`${getDefaultPrelude()} + +model User { + id Int @id @default(autoincrement()) + email String @unique + name String? 
+}`); + fs.writeFileSync(schemaFile, updatedSchema); + runCli('db push', workDir); + + // Step 3: Revert schema back to original (with required name) + const originalSchema = await formatDocument(`${getDefaultPrelude()} + +model User { + id Int @id @default(autoincrement()) + email String @unique + name String +}`); + fs.writeFileSync(schemaFile, originalSchema); + + // Step 4: Pull from database - should detect that name is now optional + runCli('db pull --indent 4', workDir); + + // Step 5: Verify that pulled schema has optional name (matching database) + const pulledSchema = getSchema(workDir); + expect(pulledSchema).toEqual(updatedSchema); + }); + + it('should update default value when database default changes', async () => { + // Step 1: Create initial schema with default value + const { workDir } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + status String @default('active') +}`, + ); + runCli('db push', workDir); + + // Step 2: Modify schema to change default value + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const updatedSchema = await formatDocument(`${getDefaultPrelude()} + +model User { + id Int @id @default(autoincrement()) + email String @unique + status String @default('pending') +}`); + fs.writeFileSync(schemaFile, updatedSchema); + runCli('db push', workDir); + + // Step 3: Revert schema back to original default + const originalSchema = await formatDocument(`${getDefaultPrelude()} + +model User { + id Int @id @default(autoincrement()) + email String @unique + status String @default('active') +}`); + fs.writeFileSync(schemaFile, originalSchema); + + // Step 4: Pull from database - should detect that default changed + runCli('db pull --indent 4', workDir); + + // Step 5: Verify that pulled schema has updated default (matching database) + const pulledSchema = getSchema(workDir); + expect(pulledSchema).toEqual(updatedSchema); + }); + }); +}); + +describe('DB pull - 
PostgreSQL specific features', () => { + it('should restore schema with multiple database schemas', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + posts Post[] + + @@schema('auth') +} + +model Post { + id Int @id @default(autoincrement()) + title String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + + @@schema('content') +}`, + { provider: 'postgresql', datasourceFields:{ schemas: ['public', 'content', 'auth'] } }, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'postgresql', datasourceFields:{ schemas: ['public', 'content', 'auth']} })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should preserve native PostgreSQL enums when schema exists', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + status Status @default(ACTIVE) + role Role @default(USER) +} + +enum Status { + ACTIVE + INACTIVE + SUSPENDED +} + +enum Role { + USER + ADMIN + MODERATOR +}`, + { provider: 'postgresql' }, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + const pulledSchema = getSchema(workDir); + + expect(pulledSchema).toEqual(schema); + }); + + it('should not modify schema with PostgreSQL-specific features', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + const { workDir, schema } = await createProject( + `model User { + id 
Int @id @default(autoincrement()) + email String @unique + status Status @default(ACTIVE) + posts Post[] + metadata Json? + + @@schema('auth') + @@index([status]) +} + +model Post { + id Int @id @default(autoincrement()) + title String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + tags String[] + + @@schema('content') + @@index([authorId]) +} + +enum Status { + ACTIVE + INACTIVE + SUSPENDED +}`, + { provider: 'postgresql', datasourceFields:{ schemas: ['public', 'content', 'auth'] } }, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + + expect(getSchema(workDir)).toEqual(schema); + }); + + it('should restore native type attributes from PostgreSQL typnames', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + // PostgreSQL introspection returns typnames like 'int2', 'float8', 'bpchar', + // but Prisma/ZenStack attributes are named @db.SmallInt, @db.DoublePrecision, @db.Char, etc. + // This test verifies the mapping works correctly. + // Note: Default native types (jsonb for Json, bytea for Bytes) are not added when pulling from zero + // because they match the default database type for that field type. 
+ const { workDir } = await createProject( + `model TypeTest { + id Int @id @default(autoincrement()) + smallNumber Int @db.SmallInt() + realNumber Float @db.Real() + doubleNum Float @db.DoublePrecision() + fixedChar String @db.Char(10) + uuid String @db.Uuid() + jsonData Json @db.Json() + jsonbData Json @db.JsonB() + binaryData Bytes @db.ByteA() +}`, + { provider: 'postgresql' }, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + // Remove schema content to simulate restoration from zero + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'postgresql' })); + + // Pull should restore non-default native type attributes + // Default types (jsonb for Json, bytea for Bytes) are not added + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + // Verify key native type mappings are restored correctly: + // - @db.SmallInt for int2 (non-default for Int which defaults to integer/int4) + // - @db.Real for float4 (non-default for Float which defaults to double precision/float8) + // - @db.Char(10) for bpchar with length (non-default for String which defaults to text) + // - @db.Uuid for uuid (non-default for String which defaults to text) + // - @db.Json for json (non-default for Json which defaults to jsonb) + expect(restoredSchema).toContain('@db.SmallInt'); + expect(restoredSchema).toContain('@db.Real'); + expect(restoredSchema).toContain('@db.Char(10)'); + expect(restoredSchema).toContain('@db.Uuid'); + expect(restoredSchema).toContain('@db.Json'); + // Default types should NOT be added when pulling from zero + expect(restoredSchema).not.toContain('@db.Integer'); // integer is default for Int + expect(restoredSchema).not.toContain('@db.DoublePrecision'); // double precision is default for Float + expect(restoredSchema).not.toContain('@db.JsonB'); // jsonb is default for Json + expect(restoredSchema).not.toContain('@db.ByteA'); // bytea is default for Bytes + }); + + 
it('should correctly map composite foreign key columns by position', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + // Composite FK: (tenantId, authorId) REFERENCES Tenant(tenantId, userId) + // The introspection must correlate by position, not match each source column + // to every target column. Without the fix, tenantId would incorrectly map to + // both tenantId AND userId in the target table. + const { workDir, schema } = await createProject( + `model Post { + id Int @id @default(autoincrement()) + title String + tenant Tenant @relation(fields: [tenantId, authorId], references: [tenantId, userId], onDelete: Cascade) + tenantId Int + authorId Int + + @@index([tenantId, authorId]) +} + +model Tenant { + tenantId Int + userId Int + name String + posts Post[] + + @@id([tenantId, userId]) +}`, + { provider: 'postgresql' }, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'postgresql' })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should pull stored generated columns as Unsupported with full expression', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + // PostgreSQL supports GENERATED ALWAYS AS (expr) STORED since PG 12. + // The introspection should include the full generation expression in the + // datatype so it is rendered as Unsupported("type GENERATED ALWAYS AS (expr) STORED"). + + // 1. Create a project with a base table (we need the DB to exist first) + const { workDir } = await createProject( + `model ComputedUsers { + id Int @id @default(autoincrement()) + firstName String + lastName String +}`, + { provider: 'postgresql' }, + ); + runCli('db push', workDir); + + // 2. 
Add a generated column via raw SQL (can't be defined in ZModel) + const { Client } = await import('pg'); + const dbName = getTestDbName('postgresql'); + const client = new Client({ connectionString: getTestDbUrl('postgresql', dbName) }); + await client.connect(); + try { + await client.query( + `ALTER TABLE "ComputedUsers" ADD COLUMN "fullName" text GENERATED ALWAYS AS ("firstName" || ' ' || "lastName") STORED` + ); + } finally { + await client.end(); + } + + // 3. Pull from zero + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'postgresql' })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + // The generated column should be pulled as Unsupported with the full expression. + // format_type returns 'text', and pg_get_expr returns the expression. + expect(restoredSchema).toEqual(await formatDocument(`${getDefaultPrelude({ provider: 'postgresql' })} + +model ComputedUsers { + id Int @id @default(autoincrement()) + firstName String + lastName String + fullName Unsupported('text GENERATED ALWAYS AS ((("firstName" || \\' \\'::text) || "lastName")) STORED')? +}`)); + }); + + it('should pull virtual generated columns as Unsupported with full expression', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + // PostgreSQL 18+ supports VIRTUAL generated columns. + // For earlier versions, only STORED is supported, so this test may need to be + // adapted. We test STORED here since it's universally supported. 
+ + const { workDir } = await createProject( + `model ComputedProducts { + id Int @id @default(autoincrement()) + price Int @default(0) + qty Int @default(0) +}`, + { provider: 'postgresql' }, + ); + runCli('db push', workDir); + + const { Client } = await import('pg'); + const dbName = getTestDbName('postgresql'); + const client = new Client({ connectionString: getTestDbUrl('postgresql', dbName) }); + await client.connect(); + try { + await client.query( + `ALTER TABLE "ComputedProducts" ADD COLUMN "total" integer GENERATED ALWAYS AS ("price" * "qty") STORED` + ); + } finally { + await client.end(); + } + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'postgresql' })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + expect(restoredSchema).toEqual(await formatDocument(`${getDefaultPrelude({ provider: 'postgresql' })} + +model ComputedProducts { + id Int @id @default(autoincrement()) + price Int @default(0) + qty Int @default(0) + total Unsupported('integer GENERATED ALWAYS AS ((price * qty)) STORED')? +}`)); + }); +}); + +describe('DB pull - MySQL specific features', () => { + it('should detect single-column unique indexes via STATISTICS', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'mysql') { + skip(); + return; + } + // MySQL's COLUMN_KEY may not reliably reflect unique indexes in all cases. + // The introspection should also check INFORMATION_SCHEMA.STATISTICS for + // NON_UNIQUE = 0 single-column indexes to correctly detect uniqueness, + // so that the index-processing skip logic (which checks index.unique + + // single-column) doesn't cause a missing @unique attribute. + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + nickname String? 
@unique +}`, + { provider: 'mysql' }, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + // Pull from zero to test introspection detects unique columns correctly + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'mysql' })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should pull stored generated columns as Unsupported with full expression', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'mysql') { + skip(); + return; + } + // MySQL supports both VIRTUAL and STORED generated columns. + // The introspection should include the full generation expression in the + // datatype so it is rendered as Unsupported("type GENERATED ALWAYS AS (expr) STORED"). + + // 1. Create a project with a base table (we need the DB to exist first) + const { workDir } = await createProject( + `model ComputedUsers { + id Int @id @default(autoincrement()) + firstName String @db.VarChar(255) + lastName String @db.VarChar(255) +}`, + { provider: 'mysql' }, + ); + runCli('db push', workDir); + + // 2. Add a generated column via raw SQL (can't be defined in ZModel) + const mysql = await import('mysql2/promise'); + const dbName = getTestDbName('mysql'); + const connection = await mysql.createConnection(getTestDbUrl('mysql', dbName)); + try { + await connection.execute( + "ALTER TABLE `ComputedUsers` ADD COLUMN `fullName` varchar(511) GENERATED ALWAYS AS (CONCAT(`firstName`, ' ', `lastName`)) STORED" + ); + } finally { + await connection.end(); + } + + // 3. Pull from zero + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'mysql' })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + // The generated column should be pulled as Unsupported with the full expression. 
+ // MySQL uses COLUMN_TYPE (e.g., 'varchar(511)') and GENERATION_EXPRESSION for the expr, + // and EXTRA contains 'STORED GENERATED' or 'VIRTUAL GENERATED'. + expect(restoredSchema).toEqual(await formatDocument(`${getDefaultPrelude({ provider: 'mysql' })} + +model ComputedUsers { + id Int @id @default(autoincrement()) + firstName String @db.VarChar(255) + lastName String @db.VarChar(255) + fullName Unsupported('varchar(511) GENERATED ALWAYS AS (concat(\`firstName\`,\\' \\',\`lastName\`)) STORED')? +}`)); + }); + + it('should pull virtual generated columns as Unsupported with full expression', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'mysql') { + skip(); + return; + } + + const { workDir } = await createProject( + `model ComputedProducts { + id Int @id @default(autoincrement()) + price Int @default(0) + qty Int @default(0) +}`, + { provider: 'mysql' }, + ); + runCli('db push', workDir); + + const mysql = await import('mysql2/promise'); + const dbName = getTestDbName('mysql'); + const connection = await mysql.createConnection(getTestDbUrl('mysql', dbName)); + try { + await connection.execute( + "ALTER TABLE `ComputedProducts` ADD COLUMN `total` int GENERATED ALWAYS AS (`price` * `qty`) VIRTUAL" + ); + } finally { + await connection.end(); + } + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'mysql' })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + expect(restoredSchema).toEqual(await formatDocument(`${getDefaultPrelude({ provider: 'mysql' })} + +model ComputedProducts { + id Int @id @default(autoincrement()) + price Int @default(0) + qty Int @default(0) + total Unsupported('int GENERATED ALWAYS AS ((\`price\` * \`qty\`)) VIRTUAL')? 
+}`)); + }); +}); + +describe('DB pull - SQLite specific features', () => { + it('should restore composite foreign key relations', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'sqlite') { + skip(); + return; + } + // Composite FK: (tenantId, authorId) REFERENCES Tenant(tenantId, userId). + // The SQLite introspection extracts FK constraint names by parsing the + // CREATE TABLE DDL. The current regex only captures a single column inside + // FOREIGN KEY(...), so composite FK constraint names are lost. Without a + // constraint name, the downstream relation grouping (pull/index.ts) skips + // the FK columns entirely and the relation is not restored. + const { workDir, schema } = await createProject( + `model Post { + id Int @id @default(autoincrement()) + title String + tenant Tenant @relation(fields: [tenantId, authorId], references: [tenantId, userId], onDelete: Cascade) + tenantId Int + authorId Int + + @@index([tenantId, authorId]) +} + +model Tenant { + tenantId Int + userId Int + name String + posts Post[] + + @@id([tenantId, userId]) +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should map columns without a declared type to Bytes', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'sqlite') { + skip(); + return; + } + // Create a minimal project and push to get the database file. + const { workDir } = await createProject(""); + + // Open the SQLite database directly and add a table with an untyped column. + // In SQLite, CREATE TABLE t("data") gives column "data" no declared type, + // which per affinity rules means BLOB affinity — should map to Bytes. 
+ const dbPath = path.join(workDir, 'zenstack', 'test.db'); + const SQLite = (await import('better-sqlite3')).default; + const db = new SQLite(dbPath); + db.exec('CREATE TABLE "UntypedTest" ("id" INTEGER PRIMARY KEY AUTOINCREMENT, "data")'); + db.close(); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + // The untyped "data" column should be pulled as Bytes (BLOB affinity), + // not as Unsupported. + expect(restoredSchema).toContain('data Bytes?'); + expect(restoredSchema).not.toContain('Unsupported'); + }); + + it('should pull stored generated columns as Unsupported', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'sqlite') { + skip(); + return; + } + // SQLite PRAGMA table_xinfo reports generated columns with hidden values: + // hidden = 2 → VIRTUAL generated column + // hidden = 3 → STORED generated column + // Both types should be pulled as Unsupported("full type definition") + // because generated columns are read-only and cannot be written to. 
+ + const { workDir } = await createProject(''); + + const dbPath = path.join(workDir, 'zenstack', 'test.db'); + const SQLite = (await import('better-sqlite3')).default; + const db = new SQLite(dbPath); + db.exec(` + CREATE TABLE "ComputedUsers" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "firstName" TEXT NOT NULL, + "lastName" TEXT NOT NULL, + "fullName" TEXT GENERATED ALWAYS AS (firstName || ' ' || lastName) STORED + ) + `); + db.close(); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + // firstName and lastName should be regular String fields + expect(restoredSchema).toEqual(await formatDocument(`${getDefaultPrelude()} + +model ComputedUsers { + id Int @id @default(autoincrement()) + firstName String + lastName String + fullName Unsupported('TEXT GENERATED ALWAYS AS (firstName || \\' \\' || lastName) STORED')? 
+}`)); + }); + + it('should pull virtual generated columns as Unsupported', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'sqlite') { + skip(); + return; + } + + const { workDir } = await createProject(''); + + const dbPath = path.join(workDir, 'zenstack', 'test.db'); + const SQLite = (await import('better-sqlite3')).default; + const db = new SQLite(dbPath); + db.exec(` + CREATE TABLE "ComputedProducts" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "price" INTEGER NOT NULL DEFAULT 0, + "qty" INTEGER NOT NULL DEFAULT 0, + "total" INTEGER GENERATED ALWAYS AS ("price" * "qty") VIRTUAL + ) + `); + db.close(); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + + expect(restoredSchema).toEqual(await formatDocument(`${getDefaultPrelude()} + +model ComputedProducts { + id Int @id @default(autoincrement()) + price Int @default(0) + qty Int @default(0) + total Unsupported('INTEGER GENERATED ALWAYS AS ("price" * "qty") VIRTUAL')? 
+}`)); + }); +}); + +describe('DB pull - SQL specific features', () => { + it('should restore enum fields from zero', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'mysql' && provider !== 'postgresql') { + skip(); + return; + } + + const { workDir, schema } = await createProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + status UserStatus @default(ACTIVE) +} + +enum UserStatus { + ACTIVE + INACTIVE + SUSPENDED +}`); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + // Remove schema content to simulate restoration from zero + fs.writeFileSync(schemaFile, getDefaultPrelude()); + + // Pull should fully restore the schema including enum fields + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toContain(`model User { + id Int @id @default(autoincrement()) + email String @unique + status UserStatus @default(ACTIVE) +}`); + + expect(restoredSchema).toContain(`enum UserStatus { + ACTIVE + INACTIVE + SUSPENDED +}`); + }); +}); diff --git a/packages/cli/test/db/push.test.ts b/packages/cli/test/db/push.test.ts new file mode 100644 index 000000000..bba9e05bd --- /dev/null +++ b/packages/cli/test/db/push.test.ts @@ -0,0 +1,18 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { createProject, runCli } from '../utils'; + +const model = ` +model User { + id String @id @default(cuid()) +} +`; + +describe('CLI db commands test', () => { + it('should generate a database with db push', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); + runCli('db push', workDir); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); + }); +}); diff --git a/packages/cli/test/format.test.ts b/packages/cli/test/format.test.ts index 9c95960a3..0bfb32d7f 100644 --- 
a/packages/cli/test/format.test.ts +++ b/packages/cli/test/format.test.ts @@ -10,8 +10,8 @@ model User { `; describe('CLI format command test', () => { - it('should format a valid schema successfully', () => { - const workDir = createProject(model); + it('should format a valid schema successfully', async () => { + const { workDir } = await createProject(model); expect(() => runCli('format', workDir)).not.toThrow(); const updatedContent = fs.readFileSync(`${workDir}/zenstack/schema.zmodel`, 'utf-8'); expect( @@ -22,12 +22,12 @@ describe('CLI format command test', () => { ).toBeTruthy(); }); - it('should silently ignore invalid schema', () => { + it('should silently ignore invalid schema', async () => { const invalidModel = ` model User { id String @id @default(cuid()) `; - const workDir = createProject(invalidModel); + const { workDir } = await createProject(invalidModel); expect(() => runCli('format', workDir)).not.toThrow(); }); }); diff --git a/packages/cli/test/generate.test.ts b/packages/cli/test/generate.test.ts index 074e88e56..6b270b4a8 100644 --- a/packages/cli/test/generate.test.ts +++ b/packages/cli/test/generate.test.ts @@ -10,28 +10,28 @@ model User { `; describe('CLI generate command test', () => { - it('should generate a TypeScript schema', () => { - const workDir = createProject(model); + it('should generate a TypeScript schema', async () => { + const { workDir } = await createProject(model); runCli('generate', workDir); expect(fs.existsSync(path.join(workDir, 'zenstack/schema.ts'))).toBe(true); expect(fs.existsSync(path.join(workDir, 'zenstack/schema.prisma'))).toBe(false); }); - it('should respect custom output directory', () => { - const workDir = createProject(model); + it('should respect custom output directory', async () => { + const { workDir } = await createProject(model); runCli('generate --output ./zen', workDir); expect(fs.existsSync(path.join(workDir, 'zen/schema.ts'))).toBe(true); }); - it('should respect custom schema location', () => { 
- const workDir = createProject(model); + it('should respect custom schema location', async () => { + const { workDir } = await createProject(model); fs.renameSync(path.join(workDir, 'zenstack/schema.zmodel'), path.join(workDir, 'zenstack/foo.zmodel')); runCli('generate --schema ./zenstack/foo.zmodel', workDir); expect(fs.existsSync(path.join(workDir, 'zenstack/schema.ts'))).toBe(true); }); - it('should respect package.json config', () => { - const workDir = createProject(model); + it('should respect package.json config', async () => { + const { workDir } = await createProject(model); fs.mkdirSync(path.join(workDir, 'foo')); fs.renameSync(path.join(workDir, 'zenstack/schema.zmodel'), path.join(workDir, 'foo/schema.zmodel')); fs.rmdirSync(path.join(workDir, 'zenstack')); @@ -45,8 +45,8 @@ describe('CLI generate command test', () => { expect(fs.existsSync(path.join(workDir, 'bar/schema.ts'))).toBe(true); }); - it('should respect package.json schema dir config', () => { - const workDir = createProject(model); + it('should respect package.json schema dir config', async () => { + const { workDir } = await createProject(model); fs.mkdirSync(path.join(workDir, 'foo')); fs.renameSync(path.join(workDir, 'zenstack/schema.zmodel'), path.join(workDir, 'foo/schema.zmodel')); fs.rmdirSync(path.join(workDir, 'zenstack')); @@ -60,15 +60,15 @@ describe('CLI generate command test', () => { expect(fs.existsSync(path.join(workDir, 'bar/schema.ts'))).toBe(true); }); - it('should respect lite option', () => { - const workDir = createProject(model); + it('should respect lite option', async () => { + const { workDir } = await createProject(model); runCli('generate --lite', workDir); expect(fs.existsSync(path.join(workDir, 'zenstack/schema.ts'))).toBe(true); expect(fs.existsSync(path.join(workDir, 'zenstack/schema-lite.ts'))).toBe(true); }); - it('should respect liteOnly option', () => { - const workDir = createProject(model); + it('should respect liteOnly option', async () => { + const { 
workDir } = await createProject(model); runCli('generate --lite-only', workDir); expect(fs.existsSync(path.join(workDir, 'zenstack/schema.ts'))).toBe(false); expect(fs.existsSync(path.join(workDir, 'zenstack/schema-lite.ts'))).toBe(true); diff --git a/packages/cli/test/migrate.test.ts b/packages/cli/test/migrate.test.ts index 56a0fec83..bb3a7cd53 100644 --- a/packages/cli/test/migrate.test.ts +++ b/packages/cli/test/migrate.test.ts @@ -10,37 +10,37 @@ model User { `; describe('CLI migrate commands test', () => { - it('should generate a database with migrate dev', () => { - const workDir = createProject(model); + it('should generate a database with migrate dev', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); runCli('migrate dev --name init', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); expect(fs.existsSync(path.join(workDir, 'zenstack/migrations'))).toBe(true); }); - it('should reset the database with migrate reset', () => { - const workDir = createProject(model); + it('should reset the database with migrate reset', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); runCli('db push', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); runCli('migrate reset --force', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); }); - it('should reset the database with migrate deploy', () => { - const workDir = createProject(model); + it('should reset the database with migrate deploy', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); runCli('migrate dev --name init', workDir); - fs.rmSync(path.join(workDir, 'zenstack/dev.db')); + 
fs.rmSync(path.join(workDir, 'zenstack/test.db')); runCli('migrate deploy', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); }); - it('supports migrate status', () => { - const workDir = createProject(model); + it('supports migrate status', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); runCli('migrate dev --name init', workDir); runCli('migrate status', workDir); }); - it('supports migrate resolve', () => { - const workDir = createProject(model); + it('supports migrate resolve', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); runCli('migrate dev --name init', workDir); // find the migration record "timestamp_init" @@ -51,7 +51,7 @@ describe('CLI migrate commands test', () => { fs.writeFileSync(path.join(workDir, 'zenstack/migrations', migration!, 'migration.sql'), 'invalid content'); // redeploy the migration, which will fail - fs.rmSync(path.join(workDir, 'zenstack/dev.db'), { force: true }); + fs.rmSync(path.join(workDir, 'zenstack/test.db'), { force: true }); try { runCli('migrate deploy', workDir); } catch { @@ -65,8 +65,8 @@ describe('CLI migrate commands test', () => { runCli(`migrate resolve --applied ${migration}`, workDir); }); - it('should throw error when neither applied nor rolled-back is provided', () => { - const workDir = createProject(model); + it('should throw error when neither applied nor rolled-back is provided', async () => { + const { workDir } = await createProject(model, { provider: 'sqlite' }); expect(() => runCli('migrate resolve', workDir)).toThrow(); }); }); diff --git a/packages/cli/test/plugins/custom-plugin.test.ts b/packages/cli/test/plugins/custom-plugin.test.ts index 084bf9cd7..3492dbbe6 100644 --- a/packages/cli/test/plugins/custom-plugin.test.ts +++ b/packages/cli/test/plugins/custom-plugin.test.ts @@ -5,8 +5,8 @@ import { createProject, 
runCli } from '../utils'; import { execSync } from 'node:child_process'; describe('Custom plugins tests', () => { - it('runs custom plugin generator', () => { - const workDir = createProject(` + it('runs custom plugin generator', async () => { + const { workDir } = await createProject(` plugin custom { provider = '../my-plugin.js' output = '../custom-output' diff --git a/packages/cli/test/plugins/prisma-plugin.test.ts b/packages/cli/test/plugins/prisma-plugin.test.ts index 739252e21..4b619c679 100644 --- a/packages/cli/test/plugins/prisma-plugin.test.ts +++ b/packages/cli/test/plugins/prisma-plugin.test.ts @@ -4,8 +4,8 @@ import { describe, expect, it } from 'vitest'; import { createProject, runCli } from '../utils'; describe('Core plugins tests', () => { - it('can automatically generate a TypeScript schema with default output', () => { - const workDir = createProject(` + it('can automatically generate a TypeScript schema with default output', async () => { + const { workDir } = await createProject(` model User { id String @id @default(cuid()) } @@ -14,8 +14,8 @@ model User { expect(fs.existsSync(path.join(workDir, 'zenstack/schema.ts'))).toBe(true); }); - it('can automatically generate a TypeScript schema with custom output', () => { - const workDir = createProject(` + it('can automatically generate a TypeScript schema with custom output', async () => { + const { workDir } = await createProject(` plugin typescript { provider = '@core/typescript' output = '../generated-schema' @@ -29,8 +29,8 @@ model User { expect(fs.existsSync(path.join(workDir, 'generated-schema/schema.ts'))).toBe(true); }); - it('can generate a Prisma schema with default output', () => { - const workDir = createProject(` + it('can generate a Prisma schema with default output', async () => { + const { workDir } = await createProject(` plugin prisma { provider = '@core/prisma' } @@ -43,8 +43,8 @@ model User { expect(fs.existsSync(path.join(workDir, 'zenstack/schema.prisma'))).toBe(true); }); - 
it('can generate a Prisma schema with custom output', () => { - const workDir = createProject(` + it('can generate a Prisma schema with custom output', async () => { + const { workDir } = await createProject(` plugin prisma { provider = '@core/prisma' output = '../prisma/schema.prisma' @@ -58,8 +58,8 @@ model User { expect(fs.existsSync(path.join(workDir, 'prisma/schema.prisma'))).toBe(true); }); - it('can generate a Prisma schema with custom output relative to zenstack.output', () => { - const workDir = createProject(` + it('can generate a Prisma schema with custom output relative to zenstack.output', async () => { + const { workDir } = await createProject(` plugin prisma { provider = '@core/prisma' output = './schema.prisma' diff --git a/packages/cli/test/utils.ts b/packages/cli/test/utils.ts index 2fafb2074..31a86dfb9 100644 --- a/packages/cli/test/utils.ts +++ b/packages/cli/test/utils.ts @@ -1,20 +1,107 @@ -import { createTestProject } from '@zenstackhq/testtools'; +import { createTestProject, getTestDbProvider } from '@zenstackhq/testtools'; +import { createHash } from 'node:crypto'; import { execSync } from 'node:child_process'; import fs from 'node:fs'; import path from 'node:path'; +import { expect } from 'vitest'; +import { formatDocument } from '@zenstackhq/language'; -const ZMODEL_PRELUDE = `datasource db { - provider = "sqlite" - url = "file:./dev.db" +const TEST_PG_CONFIG = { + host: process.env['TEST_PG_HOST'] ?? 'localhost', + port: process.env['TEST_PG_PORT'] ? parseInt(process.env['TEST_PG_PORT']) : 5432, + user: process.env['TEST_PG_USER'] ?? 'postgres', + password: process.env['TEST_PG_PASSWORD'] ?? 'postgres', +}; + +const TEST_MYSQL_CONFIG = { + host: process.env['TEST_MYSQL_HOST'] ?? 'localhost', + port: process.env['TEST_MYSQL_PORT'] ? parseInt(process.env['TEST_MYSQL_PORT']) : 3306, + user: process.env['TEST_MYSQL_USER'] ?? 'root', + password: process.env['TEST_MYSQL_PASSWORD'] ?? 
'mysql', +}; + +export function getTestDbName(provider: string) { + if (provider === 'sqlite') { + return './test.db'; + } + const testName = expect.getState().currentTestName ?? 'unnamed'; + const testPath = expect.getState().testPath ?? ''; + // digest test name + const digest = createHash('md5') + .update(testName + testPath) + .digest('hex'); + // compute a database name based on test name + return ( + 'test_' + + testName + .toLowerCase() + .replace(/[^a-z0-9_]/g, '_') + .replace(/_+/g, '_') + .substring(0, 30) + + digest.slice(0, 6) + ); +} + +export function getTestDbUrl(provider: 'sqlite' | 'postgresql' | 'mysql', dbName: string): string { + switch (provider) { + case 'sqlite': + return `file:${dbName}`; + case 'postgresql': + return `postgres://${TEST_PG_CONFIG.user}:${TEST_PG_CONFIG.password}@${TEST_PG_CONFIG.host}:${TEST_PG_CONFIG.port}/${dbName}`; + case 'mysql': + return `mysql://${TEST_MYSQL_CONFIG.user}:${TEST_MYSQL_CONFIG.password}@${TEST_MYSQL_CONFIG.host}:${TEST_MYSQL_CONFIG.port}/${dbName}`; + default: + throw new Error(`Unsupported provider: ${provider}`); + } +} + +export function getDefaultPrelude(options?: { provider?: 'sqlite' | 'postgresql' | 'mysql', datasourceFields?: Record }) { + const provider = (options?.provider || getTestDbProvider()) ?? 
'sqlite'; + const dbName = getTestDbName(provider); + let dbUrl: string; + + switch (provider) { + case 'sqlite': + dbUrl = `file:${dbName}`; + break; + case 'postgresql': + dbUrl = `postgres://${TEST_PG_CONFIG.user}:${TEST_PG_CONFIG.password}@${TEST_PG_CONFIG.host}:${TEST_PG_CONFIG.port}/${dbName}`; + break; + case 'mysql': + dbUrl = `mysql://${TEST_MYSQL_CONFIG.user}:${TEST_MYSQL_CONFIG.password}@${TEST_MYSQL_CONFIG.host}:${TEST_MYSQL_CONFIG.port}/${dbName}`; + break; + default: + throw new Error(`Unsupported provider: ${provider}`); + } + // Build fields array for proper alignment (matching ZModelCodeGenerator) + const fields: [string, string][] = [ + ['provider', `'${provider}'`], + ['url', `'${dbUrl}'`], + ...Object.entries(options?.datasourceFields || {}).map(([k, v]) => { + const value = Array.isArray(v) ? `[${v.map(item => `'${item}'`).join(', ')}]` : `'${v}'`; + return [k, value] as [string, string]; + }), + ]; + + const formattedFields = fields.map(([name, value]) => { + return ` ${name} = ${value}`; + }).join('\n'); + + const ZMODEL_PRELUDE = `datasource db {\n${formattedFields}\n}`; + return ZMODEL_PRELUDE; } -`; -export function createProject(zmodel: string, addPrelude = true) { +export async function createProject( + zmodel: string, + options?: { customPrelude?: boolean; provider?: 'sqlite' | 'postgresql' | 'mysql'; datasourceFields?: Record }, +) { const workDir = createTestProject(); fs.mkdirSync(path.join(workDir, 'zenstack'), { recursive: true }); const schemaPath = path.join(workDir, 'zenstack/schema.zmodel'); - fs.writeFileSync(schemaPath, addPrelude ? `${ZMODEL_PRELUDE}\n\n${zmodel}` : zmodel); - return workDir; + const content = options?.customPrelude ? 
zmodel : `${getDefaultPrelude({ provider: options?.provider, datasourceFields: options?.datasourceFields })}\n\n${zmodel}`; + const schema = await formatDocument(content); + fs.writeFileSync(schemaPath, schema); + return { workDir, schema }; } export function runCli(command: string, cwd: string) { diff --git a/packages/clients/tanstack-query/test/schemas/basic/models.ts b/packages/clients/tanstack-query/test/schemas/basic/models.ts index a4bcedbbd..84371224a 100644 --- a/packages/clients/tanstack-query/test/schemas/basic/models.ts +++ b/packages/clients/tanstack-query/test/schemas/basic/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema-lite"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; export type Category = $ModelResult<$Schema, "Category">; diff --git a/packages/language/package.json b/packages/language/package.json index f5bef4ac3..5d809e7a1 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -49,6 +49,16 @@ "default": "./dist/utils.cjs" } }, + "./factory": { + "import": { + "types": "./dist/factory.d.ts", + "default": "./dist/factory.js" + }, + "require": { + "types": "./dist/factory.d.cts", + "default": "./dist/factory.cjs" + } + }, "./package.json": { "import": "./package.json", "require": "./package.json" diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 4f473ed78..cb604c74a 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -120,7 +120,7 @@ function dbgenerated(expr: String?): Any { /** * Checks if the field value contains the search string. By default, the search is case-sensitive, and * "LIKE" operator is used to match. 
If `caseInSensitive` is true, "ILIKE" operator is used if - * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's * behavior is. */ function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean { @@ -135,7 +135,7 @@ function contains(field: String, search: String, caseInSensitive: Boolean?): Boo /** * Checks the field value starts with the search string. By default, the search is case-sensitive, and * "LIKE" operator is used to match. If `caseInSensitive` is true, "ILIKE" operator is used if - * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's * behavior is. */ function startsWith(field: String, search: String, caseInSensitive: Boolean?): Boolean { @@ -144,7 +144,7 @@ function startsWith(field: String, search: String, caseInSensitive: Boolean?): B /** * Checks if the field value ends with the search string. By default, the search is case-sensitive, and * "LIKE" operator is used to match. If `caseInSensitive` is true, "ILIKE" operator is used if - * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's * behavior is. 
*/ function endsWith(field: String, search: String, caseInSensitive: Boolean?): Boolean { diff --git a/packages/language/src/document.ts b/packages/language/src/document.ts index 9642e61d5..7426c606d 100644 --- a/packages/language/src/document.ts +++ b/packages/language/src/document.ts @@ -13,7 +13,7 @@ import path from 'node:path'; import { fileURLToPath } from 'node:url'; import { isDataModel, isDataSource, type Model } from './ast'; import { DB_PROVIDERS_SUPPORTING_LIST_TYPE, STD_LIB_MODULE_NAME } from './constants'; -import { createZModelServices } from './module'; +import { createZModelServices, type ZModelServices } from './module'; import { getAllFields, getDataModelAndTypeDefs, @@ -32,8 +32,10 @@ import type { ZModelFormatter } from './zmodel-formatter'; export async function loadDocument( fileName: string, additionalModelFiles: string[] = [], + mergeImports: boolean = true, ): Promise< - { success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] } + | { success: true; model: Model; warnings: string[]; services: ZModelServices } + | { success: false; errors: string[]; warnings: string[] } > { const { ZModelLanguage: services } = createZModelServices(false); const extensions = services.LanguageMetaData.fileExtensions; @@ -121,17 +123,19 @@ export async function loadDocument( const model = document.parseResult.value as Model; - // merge all declarations into the main document - const imported = mergeImportsDeclarations(langiumDocuments, model); + if (mergeImports) { + // merge all declarations into the main document + const imported = mergeImportsDeclarations(langiumDocuments, model); - // remove imported documents - imported.forEach((model) => { - langiumDocuments.deleteDocument(model.$document!.uri); - services.shared.workspace.IndexManager.remove(model.$document!.uri); - }); + // remove imported documents + imported.forEach((model) => { + langiumDocuments.deleteDocument(model.$document!.uri); + 
services.shared.workspace.IndexManager.remove(model.$document!.uri); + }); + } // extra validation after merging imported declarations - const additionalErrors = validationAfterImportMerge(model); + const additionalErrors = mergeImports === true ? validationAfterImportMerge(model) : []; if (additionalErrors.length > 0) { return { success: false, @@ -143,6 +147,7 @@ export async function loadDocument( return { success: true, model: document.parseResult.value as Model, + services, warnings, }; } diff --git a/packages/language/src/factory/ast-factory.ts b/packages/language/src/factory/ast-factory.ts new file mode 100644 index 000000000..7618f0738 --- /dev/null +++ b/packages/language/src/factory/ast-factory.ts @@ -0,0 +1,52 @@ +import { type AstNode } from '../ast'; + +export type ContainerProps = { + $container: T; + $containerProperty?: string; + $containerIndex?: number; +}; + +type NodeFactoriesFor = { + [K in keyof N as {} extends Pick ? never : K]: N[K] extends (infer U)[] + ? (AstFactory | U)[] + : AstFactory | N[K]; +} & { + [K in keyof N as {} extends Pick ? K : never]?: N[K] extends (infer U)[] + ? (AstFactory | U)[] + : AstFactory | N[K]; +}; + +export abstract class AstFactory { + node = {} as T; + constructor({ type, node }: { type: T['$type']; node?: Partial }) { + (this.node as any).$type = type; + if (node) { + this.update(node); + } + } + setContainer(container: T['$container']) { + (this.node as any).$container = container; + return this; + } + + get(params?: ContainerProps): T { + if (params) this.update(params as any); + return this.node; + } + update(nodeArg: Partial>): T { + const keys = Object.keys(nodeArg as object); + keys.forEach((key) => { + const child = (nodeArg as any)[key]; + if (child instanceof AstFactory) { + (this.node as any)[key] = child.get({ $container: this.node as any }); + } else if (Array.isArray(child)) { + (this.node as any)[key] = child.map((item: any) => + item instanceof AstFactory ? 
item.get({ $container: this.node as any }) : item, + ); + } else { + (this.node as any)[key] = child; + } + }); + return this.node; + } +} diff --git a/packages/language/src/factory/attribute.ts b/packages/language/src/factory/attribute.ts new file mode 100644 index 000000000..b59e35ef1 --- /dev/null +++ b/packages/language/src/factory/attribute.ts @@ -0,0 +1,281 @@ +import { AstFactory } from './ast-factory'; +import { + Attribute, + AttributeArg, + AttributeParam, + AttributeParamType, + DataFieldAttribute, + DataModelAttribute, + Expression, + InternalAttribute, + TypeDeclaration, + type Reference, + type RegularID, +} from '../ast'; +import { ExpressionBuilder } from './expression'; + +export class DataFieldAttributeFactory extends AstFactory { + args: AttributeArgFactory[] = []; + decl?: Reference; + constructor() { + super({ type: DataFieldAttribute, node: { args: [] } }); + } + setDecl(decl: Attribute) { + if (!decl) { + throw new Error('Attribute declaration is required'); + } + this.decl = { + $refText: decl.name, + ref: decl, + }; + this.update({ + decl: this.decl, + }); + return this; + } + addArg(builder: (b: ExpressionBuilder) => AstFactory, name?: string) { + const factory = new AttributeArgFactory().setValue(builder); + if (name) { + factory.setName(name); + } + this.args.push(factory); + this.update({ + args: this.args, + }); + return this; + } +} + +export class DataModelAttributeFactory extends AstFactory { + args: AttributeArgFactory[] = []; + decl?: Reference; + constructor() { + super({ type: DataModelAttribute, node: { args: [] } }); + } + setDecl(decl: Attribute) { + if (!decl) { + throw new Error('Attribute declaration is required'); + } + this.decl = { + $refText: decl.name, + ref: decl, + }; + this.update({ + decl: this.decl, + }); + return this; + } + addArg(builder: (b: ExpressionBuilder) => AstFactory, name?: string) { + const factory = new AttributeArgFactory().setValue(builder); + if (name) { + factory.setName(name); + } + 
this.args.push(factory); + this.update({ + args: this.args, + }); + return this; + } +} + +export class AttributeArgFactory extends AstFactory { + name?: RegularID = ''; + value?: AstFactory; + + constructor() { + super({ type: AttributeArg }); + } + + setName(name: RegularID) { + this.name = name; + this.update({ + name: this.name, + }); + return this; + } + + setValue(builder: (b: ExpressionBuilder) => AstFactory) { + this.value = builder(ExpressionBuilder()); + this.update({ + value: this.value, + }); + return this; + } +} + +export class InternalAttributeFactory extends AstFactory { + decl?: Reference; + args: AttributeArgFactory[] = []; + + constructor() { + super({ type: InternalAttribute, node: { args: [] } }); + } + + setDecl(decl: Attribute) { + this.decl = { + $refText: decl.name, + ref: decl, + }; + this.update({ + decl: this.decl, + }); + return this; + } + + addArg(builder: (b: ExpressionBuilder) => AstFactory, name?: string) { + const factory = new AttributeArgFactory().setValue(builder); + if (name) { + factory.setName(name); + } + this.args.push(factory); + this.update({ + args: this.args, + }); + return this; + } +} + +export class AttributeParamFactory extends AstFactory { + attributes: InternalAttributeFactory[] = []; + comments: string[] = []; + default?: boolean; + name?: RegularID; + type?: AttributeParamTypeFactory; + + constructor() { + super({ + type: AttributeParam, + node: { + comments: [], + attributes: [], + }, + }); + } + + addAttribute(builder: (b: InternalAttributeFactory) => InternalAttributeFactory) { + this.attributes.push(builder(new InternalAttributeFactory())); + this.update({ + attributes: this.attributes, + }); + return this; + } + + setComments(comments: string[]) { + this.comments = comments; + this.update({ + comments: this.comments, + }); + return this; + } + + setDefault(defaultValue: boolean) { + this.default = defaultValue; + this.update({ + default: this.default, + }); + return this; + } + + setName(name: string) { + 
this.name = name; + this.update({ + name: this.name, + }); + return this; + } + + setType(builder: (b: AttributeParamTypeFactory) => AttributeParamTypeFactory) { + this.type = builder(new AttributeParamTypeFactory()); + this.update({ + type: this.type, + }); + return this; + } +} + +export class AttributeParamTypeFactory extends AstFactory { + array?: boolean; + optional?: boolean; + reference?: Reference; + type?: AttributeParamType['type']; + constructor() { + super({ type: AttributeParamType }); + } + setArray(array: boolean) { + this.array = array; + this.update({ + array: this.array, + }); + return this; + } + + setOptional(optional: boolean) { + this.optional = optional; + this.update({ + optional: this.optional, + }); + return this; + } + + setReference(reference: TypeDeclaration) { + this.reference = { + $refText: reference.name, + ref: reference, + }; + this.update({ + reference: this.reference, + }); + return this; + } + + setType(type: AttributeParamType['type']) { + this.type = type; + this.update({ + type: this.type, + }); + return this; + } +} + +export class AttributeFactory extends AstFactory { + name?: string; + comments: string[] = []; + attributes: InternalAttributeFactory[] = []; + params: AttributeParamFactory[] = []; + + constructor() { + super({ type: Attribute, node: { comments: [], attributes: [], params: [] } }); + } + + setName(name: string) { + this.name = name; + this.update({ + name: this.name, + }); + return this; + } + + setComments(comments: string[]) { + this.comments = comments; + this.update({ + comments: this.comments, + }); + return this; + } + + addAttribute(builder: (b: InternalAttributeFactory) => InternalAttributeFactory) { + this.attributes.push(builder(new InternalAttributeFactory())); + this.update({ + attributes: this.attributes, + }); + return this; + } + + addParam(builder: (b: AttributeParamFactory) => AttributeParamFactory) { + this.params.push(builder(new AttributeParamFactory())); + this.update({ + params: 
this.params, + }); + return this; + } +} diff --git a/packages/language/src/factory/declaration.ts b/packages/language/src/factory/declaration.ts new file mode 100644 index 000000000..a6f772a20 --- /dev/null +++ b/packages/language/src/factory/declaration.ts @@ -0,0 +1,373 @@ +import { AstFactory } from './ast-factory'; +import { AbstractDeclaration, type Reference } from '../ast'; +import { + type BuiltinType, + DataField, + DataFieldType, + DataModel, + Enum, + EnumField, + LiteralExpr, + Model, + ModelImport, + type RegularID, + type RegularIDWithTypeNames, + TypeDeclaration, + type TypeDef, + UnsupportedFieldType, +} from '../generated/ast'; +import { AttributeFactory, DataFieldAttributeFactory, DataModelAttributeFactory } from './attribute'; +import { ExpressionBuilder } from './expression'; +export const DeclarationBuilder = () => + ({ + get Attribute() { + return new AttributeFactory(); + }, + get DataModel() { + return new DataModelFactory(); + }, + get DataSource(): any { + throw new Error('DataSource is not implemented'); + }, + get Enum() { + return new EnumFactory(); + }, + get FunctionDecl(): any { + throw new Error('FunctionDecl is not implemented'); + }, + get GeneratorDecl(): any { + throw new Error('GeneratorDecl is not implemented'); + }, + get Plugin(): any { + throw new Error('Plugin is not implemented'); + }, + get Procedure(): any { + throw new Error('Procedure is not implemented'); + }, + get TypeDef(): any { + throw new Error('TypeDef is not implemented'); + }, + }) satisfies DeclarationBuilderType; +type DeclarationBuilderType = { + [K in T['$type']]: AstFactory>; +}; +type DeclarationBuilderMap = ReturnType; + +export type DeclarationBuilder = Pick< + DeclarationBuilderMap, + Extract +>; + +export class DataModelFactory extends AstFactory { + attributes: DataModelAttributeFactory[] = []; + baseModel?: Reference; + comments: string[] = []; + fields: DataFieldFactory[] = []; + isView?: boolean; + mixins: Reference[] = []; + name?: RegularID; 
+ + constructor() { + super({ + type: DataModel, + node: { + attributes: [], + comments: [], + fields: [], + mixins: [], + }, + }); + } + + addAttribute(builder: (attr: DataModelAttributeFactory) => DataModelAttributeFactory) { + this.attributes.push(builder(new DataModelAttributeFactory()).setContainer(this.node)); + this.update({ + attributes: this.attributes, + }); + return this; + } + + setBaseModel(model: Reference) { + this.baseModel = model; + this.update({ + baseModel: this.baseModel, + }); + return this; + } + + setComments(comments: string[]) { + this.comments = comments; + this.update({ + comments: this.comments, + }); + return this; + } + + addComment(comment: string) { + this.comments.push(comment); + this.update({ + comments: this.comments, + }); + return this; + } + + addField(builder: (field: DataFieldFactory) => DataFieldFactory) { + this.fields.push(builder(new DataFieldFactory()).setContainer(this.node)); + this.update({ + fields: this.fields, + }); + return this; + } + + setIsView(isView: boolean) { + this.isView = isView; + this.update({ + isView: this.isView, + }); + return this; + } + + addMixin(mixin: Reference) { + this.mixins.push(mixin); + this.update({ + mixins: this.mixins, + }); + return this; + } + + setName(name: string) { + this.name = name; + this.update({ + name: this.name, + }); + return this; + } +} + +export class DataFieldFactory extends AstFactory { + attributes: DataFieldAttributeFactory[] = []; + comments: string[] = []; + name?: string; + type?: DataFieldTypeFactory; + + constructor() { + super({ type: DataField, node: { attributes: [], comments: [] } }); + } + + addAttribute( + builder: ((attr: DataFieldAttributeFactory) => DataFieldAttributeFactory) | DataFieldAttributeFactory, + ) { + if (builder instanceof DataFieldAttributeFactory) { + builder.setContainer(this.node); + this.attributes.push(builder); + } else { + const attr = builder(new DataFieldAttributeFactory()); + attr.setContainer(this.node); + 
this.attributes.push(attr); + } + this.update({ + attributes: this.attributes, + }); + return this; + } + + setComments(comments: string[]) { + this.comments = comments; + this.update({ + comments: this.comments, + }); + return this; + } + + setName(name: string) { + this.name = name; + this.update({ + name: this.name, + }); + return this; + } + + setType(builder: (type: DataFieldTypeFactory) => DataFieldTypeFactory) { + this.type = builder(new DataFieldTypeFactory()).setContainer(this.node); + this.update({ + type: this.type, + }); + return this; + } +} + +export class DataFieldTypeFactory extends AstFactory { + array?: boolean; + optional?: boolean; + reference?: Reference; + type?: BuiltinType; + unsupported?: UnsupportedFieldTypeFactory; + + constructor() { + super({ type: DataFieldType }); + } + + setArray(array: boolean) { + this.array = array; + this.update({ + array: this.array, + }); + return this; + } + + setOptional(optional: boolean) { + this.optional = optional; + this.update({ + optional: this.optional, + }); + return this; + } + + setReference(reference: TypeDeclaration) { + this.reference = { + $refText: reference.name, + ref: reference, + }; + this.update({ + reference: this.reference, + }); + return this; + } + + setType(type: BuiltinType) { + this.type = type; + this.update({ + type: this.type, + }); + return this; + } + + setUnsupported(builder: (a: UnsupportedFieldTypeFactory) => UnsupportedFieldTypeFactory) { + this.unsupported = builder(new UnsupportedFieldTypeFactory()).setContainer(this.node); + this.update({ + unsupported: this.unsupported, + }); + return this; + } +} + +export class UnsupportedFieldTypeFactory extends AstFactory { + value?: AstFactory; + constructor() { + super({ type: UnsupportedFieldType }); + } + setValue(builder: (value: ExpressionBuilder) => AstFactory) { + this.value = builder(ExpressionBuilder()); + this.update({ + value: this.value!, + }); + return this; + } +} + +export class ModelFactory extends AstFactory { + 
declarations: AstFactory[] = []; + imports: ModelImportFactory[] = []; + constructor() { + super({ type: Model, node: { declarations: [], imports: [] } }); + } + addImport(builder: (b: ModelImportFactory) => ModelImportFactory) { + this.imports.push(builder(new ModelImportFactory()).setContainer(this.node)); + this.update({ + imports: this.imports, + }); + return this; + } + addDeclaration(builder: (b: DeclarationBuilder) => AstFactory) { + this.declarations.push(builder(DeclarationBuilder()).setContainer(this.node)); + this.update({ + declarations: this.declarations, + }); + return this; + } +} + +export class ModelImportFactory extends AstFactory { + path?: string | undefined; + + constructor() { + super({ type: ModelImport }); + } + + setPath(path: string) { + this.path = path; + this.update({ + path: this.path, + }); + return this; + } +} + +export class EnumFactory extends AstFactory { + name?: string; + comments: string[] = []; + fields: EnumFieldFactory[] = []; + attributes: DataModelAttributeFactory[] = []; + + constructor() { + super({ type: Enum, node: { comments: [], fields: [], attributes: [] } }); + } + + addField(builder: (b: EnumFieldFactory) => EnumFieldFactory) { + this.fields.push(builder(new EnumFieldFactory()).setContainer(this.node)); + this.update({ + fields: this.fields, + }); + return this; + } + + addAttribute(builder: (b: DataModelAttributeFactory) => DataModelAttributeFactory) { + this.attributes.push(builder(new DataModelAttributeFactory()).setContainer(this.node)); + this.update({ + attributes: this.attributes, + }); + return this; + } + + setName(name: string) { + this.name = name; + this.update({ + name: this.name, + }); + return this; + } +} + +export class EnumFieldFactory extends AstFactory { + name?: RegularIDWithTypeNames; + comments: string[] = []; + attributes: DataFieldAttributeFactory[] = []; + + constructor() { + super({ type: EnumField, node: { comments: [], attributes: [] } }); + } + + setName(name: RegularIDWithTypeNames) 
{ + this.name = name; + this.update({ + name: this.name, + }); + return this; + } + + addAttribute(builder: (b: DataFieldAttributeFactory) => DataFieldAttributeFactory) { + this.attributes.push(builder(new DataFieldAttributeFactory()).setContainer(this.node)); + this.update({ + attributes: this.attributes, + }); + return this; + } + + addComment(comment: string) { + this.comments.push(comment); + this.update({ + comments: this.comments, + }); + return this; + } +} diff --git a/packages/language/src/factory/expression.ts b/packages/language/src/factory/expression.ts new file mode 100644 index 000000000..19fe16af1 --- /dev/null +++ b/packages/language/src/factory/expression.ts @@ -0,0 +1,308 @@ +import type { Reference } from 'langium'; +import { AstFactory } from './ast-factory'; +import { + Argument, + ArrayExpr, + BinaryExpr, + FieldInitializer, + FunctionDecl, + InvocationExpr, + MemberAccessExpr, + MemberAccessTarget, + ObjectExpr, + ReferenceArg, + ReferenceExpr, + ReferenceTarget, + UnaryExpr, + type Expression, + type RegularID, +} from '../ast'; +import { + BooleanLiteralFactory, + NullExprFactory, + NumberLiteralFactory, + StringLiteralFactory, + ThisExprFactory, +} from './primitives'; + +export const ExpressionBuilder = () => + ({ + get ArrayExpr() { + return new ArrayExprFactory(); + }, + get BinaryExpr() { + return new BinaryExprFactory(); + }, + get BooleanLiteral() { + return new BooleanLiteralFactory(); + }, + get InvocationExpr() { + return new InvocationExprFactory(); + }, + get MemberAccessExpr() { + return new MemberAccessExprFactory(); + }, + get NullExpr() { + return new NullExprFactory(); + }, + get NumberLiteral() { + return new NumberLiteralFactory(); + }, + get ObjectExpr() { + return new ObjectExprFactory(); + }, + get ReferenceExpr() { + return new ReferenceExprFactory(); + }, + get StringLiteral() { + return new StringLiteralFactory(); + }, + get ThisExpr() { + return new ThisExprFactory(); + }, + get UnaryExpr() { + return new 
UnaryExprFactory(); + }, + }) satisfies ExpressionBuilderType; +type ExpressionBuilderType = { + [K in T['$type']]: AstFactory>; +}; + +type ExpressionFactoryMap = ReturnType; + +export type ExpressionBuilder = Pick< + ExpressionFactoryMap, + Extract +>; + +export class UnaryExprFactory extends AstFactory { + operand?: AstFactory; + + constructor() { + super({ type: UnaryExpr, node: { operator: '!' } }); + } + + setOperand(builder: (a: ExpressionBuilder) => AstFactory) { + this.operand = builder(ExpressionBuilder()); + this.update({ + operand: this.operand, + }); + return this; + } +} + +export class ReferenceExprFactory extends AstFactory { + target?: Reference; + args: ReferenceArgFactory[] = []; + + constructor() { + super({ type: ReferenceExpr, node: { args: [] } }); + } + + setTarget(target: ReferenceTarget) { + this.target = { + $refText: target.name, + ref: target, + }; + this.update({ + target: this.target, + }); + return this; + } + + addArg(builder: (a: ExpressionBuilder) => AstFactory, name?: string) { + const arg = new ReferenceArgFactory().setValue(builder); + if (name) { + arg.setName(name); + } + this.args.push(arg); + this.update({ + args: this.args, + }); + return this; + } +} + +export class ReferenceArgFactory extends AstFactory { + name?: string; + value?: AstFactory; + + constructor() { + super({ type: ReferenceArg }); + } + + setName(name: string) { + this.name = name; + this.update({ + name: this.name, + }); + return this; + } + + setValue(builder: (a: ExpressionBuilder) => AstFactory) { + this.value = builder(ExpressionBuilder()); + this.update({ + value: this.value, + }); + return this; + } +} + +export class MemberAccessExprFactory extends AstFactory { + member?: Reference; + operand?: AstFactory; + + constructor() { + super({ type: MemberAccessExpr }); + } + + setMember(target: Reference) { + this.member = target; + this.update({ + member: this.member, + }); + return this; + } + + setOperand(builder: (b: ExpressionBuilder) => AstFactory) 
{ + this.operand = builder(ExpressionBuilder()); + this.update({ + operand: this.operand, + }); + return this; + } +} + +export class ObjectExprFactory extends AstFactory { + fields: FieldInitializerFactory[] = []; + + constructor() { + super({ type: ObjectExpr, node: { fields: [] } }); + } + + addField(builder: (b: FieldInitializerFactory) => FieldInitializerFactory) { + this.fields.push(builder(new FieldInitializerFactory())); + this.update({ + fields: this.fields, + }); + return this; + } +} + +export class FieldInitializerFactory extends AstFactory { + name?: RegularID; + value?: AstFactory; + + constructor() { + super({ type: FieldInitializer }); + } + + setName(name: RegularID) { + this.name = name; + this.update({ + name: this.name!, + }); + return this; + } + + setValue(builder: (a: ExpressionBuilder) => AstFactory) { + this.value = builder(ExpressionBuilder()); + this.update({ + value: this.value!, + }); + return this; + } +} + +export class InvocationExprFactory extends AstFactory { + args: ArgumentFactory[] = []; + function?: Reference; + + constructor() { + super({ type: InvocationExpr, node: { args: [] } }); + } + + addArg(builder: (arg: ArgumentFactory) => ArgumentFactory) { + this.args.push(builder(new ArgumentFactory())); + this.update({ + args: this.args, + }); + return this; + } + + setFunction(value: FunctionDecl) { + this.function = { + $refText: value.name, + ref: value, + }; + this.update({ + function: this.function!, + }); + return this; + } +} + +export class ArgumentFactory extends AstFactory { + value?: AstFactory; + + constructor() { + super({ type: Argument }); + } + + setValue(builder: (a: ExpressionBuilder) => AstFactory) { + this.value = builder(ExpressionBuilder()); + this.update({ + value: this.value!, + }); + return this; + } +} + +export class ArrayExprFactory extends AstFactory { + items: AstFactory[] = []; + + constructor() { + super({ type: ArrayExpr, node: { items: [] } }); + } + + addItem(builder: (a: ExpressionBuilder) => 
AstFactory) { + this.items.push(builder(ExpressionBuilder())); + this.update({ + items: this.items, + }); + return this; + } +} + +export class BinaryExprFactory extends AstFactory { + operator?: BinaryExpr['operator']; + right?: AstFactory; + left?: AstFactory; + // TODO: add support for CollectionPredicateBinding + + constructor() { + super({ type: BinaryExpr }); + } + + setOperator(operator: BinaryExpr['operator']) { + this.operator = operator; + this.update({ + operator: this.operator!, + }); + return this; + } + setRight(builder: (arg: ExpressionBuilder) => AstFactory) { + this.right = builder(ExpressionBuilder()); + this.update({ + right: this.right!, + }); + return this; + } + setLeft(builder: (arg: ExpressionBuilder) => AstFactory) { + this.left = builder(ExpressionBuilder()); + this.update({ + left: this.left!, + }); + return this; + } +} diff --git a/packages/language/src/factory/index.ts b/packages/language/src/factory/index.ts new file mode 100644 index 000000000..1ea2a286b --- /dev/null +++ b/packages/language/src/factory/index.ts @@ -0,0 +1,5 @@ +export * from './ast-factory'; +export * from './primitives'; +export * from './expression'; +export * from './declaration'; +export * from './attribute'; diff --git a/packages/language/src/factory/primitives.ts b/packages/language/src/factory/primitives.ts new file mode 100644 index 000000000..e97310d54 --- /dev/null +++ b/packages/language/src/factory/primitives.ts @@ -0,0 +1,61 @@ +import { AstFactory } from './ast-factory'; +import { BooleanLiteral, NullExpr, NumberLiteral, StringLiteral, ThisExpr } from '../ast'; + +export class ThisExprFactory extends AstFactory { + constructor() { + super({ type: ThisExpr, node: { value: 'this' } }); + } +} + +export class NullExprFactory extends AstFactory { + constructor() { + super({ type: NullExpr, node: { value: 'null' } }); + } +} + +export class NumberLiteralFactory extends AstFactory { + value?: number | string; + + constructor() { + super({ type: NumberLiteral 
}); + } + + setValue(value: number | string) { + this.value = value; + this.update({ + value: this.value.toString(), + }); + return this; + } +} + +export class StringLiteralFactory extends AstFactory { + value?: string; + + constructor() { + super({ type: StringLiteral }); + } + + setValue(value: string) { + this.value = value; + this.update({ + value: this.value, + }); + return this; + } +} +export class BooleanLiteralFactory extends AstFactory { + value?: boolean; + + constructor() { + super({ type: BooleanLiteral }); + } + + setValue(value: boolean) { + this.value = value; + this.update({ + value: this.value, + }); + return this; + } +} diff --git a/packages/language/src/validators/datamodel-validator.ts b/packages/language/src/validators/datamodel-validator.ts index 6c5d18ffd..d2fcd155d 100644 --- a/packages/language/src/validators/datamodel-validator.ts +++ b/packages/language/src/validators/datamodel-validator.ts @@ -44,13 +44,15 @@ export default class DataModelValidator implements AstValidator { const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique')); const modelLevelIds = getModelIdFields(dm); const modelUniqueFields = getModelUniqueFields(dm); + const ignore = hasAttribute(dm, '@@ignore'); if ( !dm.isView && idFields.length === 0 && modelLevelIds.length === 0 && uniqueFields.length === 0 && - modelUniqueFields.length === 0 + modelUniqueFields.length === 0 && + !ignore ) { accept( 'error', diff --git a/packages/language/src/zmodel-code-generator.ts b/packages/language/src/zmodel-code-generator.ts index 1e0366ede..e68ba7735 100644 --- a/packages/language/src/zmodel-code-generator.ts +++ b/packages/language/src/zmodel-code-generator.ts @@ -28,6 +28,7 @@ import { LiteralExpr, MemberAccessExpr, Model, + ModelImport, NullExpr, NumberLiteral, ObjectExpr, @@ -70,7 +71,7 @@ function gen(name: string) { */ export class ZModelCodeGenerator { private readonly options: ZModelCodeOptions; - + private readonly quote: 
string; constructor(options?: Partial) { this.options = { binaryExprNumberOfSpaces: options?.binaryExprNumberOfSpaces ?? 1, @@ -78,6 +79,7 @@ export class ZModelCodeGenerator { indent: options?.indent ?? 4, quote: options?.quote ?? 'single', }; + this.quote = this.options.quote === 'double' ? '"' : "'"; } /** @@ -91,9 +93,16 @@ export class ZModelCodeGenerator { return handler.value.call(this, ast); } + private quotedStr(val: string): string { + const trimmedVal = val.replace(new RegExp(`(? this.generate(d)).join('\n\n'); + return `${ast.imports.map((d) => this.generate(d)).join('\n')}${ast.imports.length > 0 ? '\n\n' : ''}${ast.declarations + .map((d) => this.generate(d)) + .join('\n\n')}`; } @gen(DataSource) @@ -103,18 +112,39 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')} }`; } + @gen(ModelImport) + private _generateModelImport(ast: ModelImport) { + return `import ${this.quotedStr(ast.path)}`; + } + @gen(Enum) private _generateEnum(ast: Enum) { - return `enum ${ast.name} { -${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')} + const comments = `${ast.comments.join('\n')}\n`; + return `${ast.comments.length > 0 ? comments : ''}enum ${ast.name} { +${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ + ast.attributes.length > 0 + ? '\n\n' + ast.attributes.map((x) => this.indent + this.generate(x)).join('\n') + : '' + } }`; } @gen(EnumField) private _generateEnumField(ast: EnumField) { - return `${ast.name}${ + const fieldLine = `${ast.name}${ ast.attributes.length > 0 ? ' ' + ast.attributes.map((x) => this.generate(x)).join(' ') : '' }`; + + if (ast.comments.length === 0) { + return fieldLine; + } + + // Build comment block with proper indentation: + // - First comment: no indent (caller adds it via `this.indent + this.generate(x)`) + // - Subsequent comments: add indent + // - Field line: add indent (since it comes after the comment block) + const commentLines = ast.comments.map((c, i) => (i === 0 ? 
c : this.indent + c)); + return `${commentLines.join('\n')}\n${this.indent}${fieldLine}`; } @gen(GeneratorDecl) @@ -159,8 +189,10 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')} @gen(DataModel) private _generateDataModel(ast: DataModel) { - return `${ast.isView ? 'view' : 'model'} ${ast.name}${ - ast.mixins.length > 0 ? ' mixes ' + ast.mixins.map((x) => x.$refText).join(', ') : '' + const comments = `${ast.comments.join('\n')}\n`; + + return `${ast.comments.length > 0 ? comments : ''}${ast.isView ? 'view' : 'model'} ${ast.name}${ + ast.mixins.length > 0 ? ' with ' + ast.mixins.map((x) => x.$refText).join(', ') : '' } { ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ ast.attributes.length > 0 @@ -172,9 +204,20 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(DataField) private _generateDataField(ast: DataField) { - return `${ast.name} ${this.fieldType(ast.type)}${ + const fieldLine = `${ast.name} ${this.fieldType(ast.type)}${ ast.attributes.length > 0 ? ' ' + ast.attributes.map((x) => this.generate(x)).join(' ') : '' }`; + + if (ast.comments.length === 0) { + return fieldLine; + } + + // Build comment block with proper indentation: + // - First comment: no indent (caller adds it via `this.indent + this.generate(x)`) + // - Subsequent comments: add indent + // - Field line: add indent (since it comes after the comment block) + const commentLines = ast.comments.map((c, i) => (i === 0 ? c : this.indent + c)); + return `${commentLines.join('\n')}\n${this.indent}${fieldLine}`; } private fieldType(type: DataFieldType) { @@ -226,7 +269,7 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(StringLiteral) private _generateLiteralExpr(ast: LiteralExpr) { - return this.options.quote === 'single' ? 
`'${ast.value}'` : `"${ast.value}"`; + return this.quotedStr(ast.value as string); } @gen(NumberLiteral) @@ -271,7 +314,7 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(ReferenceArg) private _generateReferenceArg(ast: ReferenceArg) { - return `${ast.name}:${this.generate(ast.value)}`; + return `${ast.name}: ${this.generate(ast.value)}`; } @gen(MemberAccessExpr) diff --git a/packages/language/tsup.config.ts b/packages/language/tsup.config.ts index 0d5d2b6c4..48282a08c 100644 --- a/packages/language/tsup.config.ts +++ b/packages/language/tsup.config.ts @@ -5,6 +5,7 @@ export default defineConfig({ index: 'src/index.ts', ast: 'src/ast.ts', utils: 'src/utils.ts', + factory: 'src/factory/index.ts', }, outDir: 'dist', splitting: false, diff --git a/packages/orm/src/client/client-impl.ts b/packages/orm/src/client/client-impl.ts index fc8f92c7c..acf888f8a 100644 --- a/packages/orm/src/client/client-impl.ts +++ b/packages/orm/src/client/client-impl.ts @@ -75,6 +75,11 @@ export class ClientImpl { ...this.$options.functions, }; + if (!baseClient) { + // validate computed fields configuration once for the root client + this.validateComputedFieldsConfig(); + } + // here we use kysely's props constructor so we can pass a custom query executor if (baseClient) { this.kyselyProps = { @@ -139,6 +144,39 @@ export class ClientImpl { return new ClientImpl(this.schema, this.$options, this, executor); } + /** + * Validates that all computed fields in the schema have corresponding configurations. + */ + private validateComputedFieldsConfig() { + const computedFieldsConfig = + 'computedFields' in this.$options + ? 
(this.$options.computedFields as Record | undefined) + : undefined; + + for (const [modelName, modelDef] of Object.entries(this.$schema.models)) { + if (modelDef.computedFields) { + for (const fieldName of Object.keys(modelDef.computedFields)) { + const modelConfig = computedFieldsConfig?.[modelName]; + const fieldConfig = modelConfig?.[fieldName]; + // Check if the computed field has a configuration + if (fieldConfig === null || fieldConfig === undefined) { + throw createConfigError( + `Computed field "${fieldName}" in model "${modelName}" does not have a configuration. ` + + `Please provide an implementation in the computedFields option.`, + ); + } + // Check that the configuration is a function + if (typeof fieldConfig !== 'function') { + throw createConfigError( + `Computed field "${fieldName}" in model "${modelName}" has an invalid configuration: ` + + `expected a function but received ${typeof fieldConfig}.`, + ); + } + } + } + } + } + // overload for interactive transaction $transaction( callback: (tx: ClientContract) => Promise, diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 413131269..ac6fcf00f 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -1359,18 +1359,18 @@ export class TsSchemaGenerator { ts.factory.createImportDeclaration( undefined, ts.factory.createImportClause( - false, + true, undefined, ts.factory.createNamedImports([ ts.factory.createImportSpecifier( - true, + false, undefined, ts.factory.createIdentifier(`ModelResult as $ModelResult`), ), ...(model.declarations.some(isTypeDef) ? 
[ ts.factory.createImportSpecifier( - true, + false, undefined, ts.factory.createIdentifier(`TypeDefResult as $TypeDefResult`), ), diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index 89148c405..69513eeb4 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -104,6 +104,11 @@ type ExtraTestClientOptions = { globPattern: string; destination: string; }[]; + + /** + * Computed fields configuration for tests. + */ + computedFields?: import('@zenstackhq/orm').ComputedFieldsOptions; }; export type CreateTestClientOptions = Omit, 'dialect'> & diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3f519e5c3..f6e6a972e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -201,6 +201,9 @@ importers: '@zenstackhq/orm': specifier: workspace:* version: link:../orm + '@zenstackhq/schema': + specifier: workspace:* + version: link:../schema '@zenstackhq/sdk': specifier: workspace:* version: link:../sdk @@ -12980,7 +12983,7 @@ snapshots: eslint: 9.29.0(jiti@2.6.1) eslint-import-resolver-node: 0.3.9 eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-jsx-a11y: 6.10.2(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react: 7.37.5(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react-hooks: 7.0.1(eslint@9.29.0(jiti@2.6.1)) @@ -13013,7 +13016,7 @@ snapshots: 
tinyglobby: 0.2.15 unrs-resolver: 1.11.1 optionalDependencies: - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) transitivePeerDependencies: - supports-color @@ -13028,7 +13031,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)): + eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)): dependencies: '@rtsao/scc': 1.1.0 array-includes: 3.1.9 diff --git a/samples/next.js/zenstack/models.ts b/samples/next.js/zenstack/models.ts index 3314c7d48..d878eac47 100644 --- a/samples/next.js/zenstack/models.ts +++ b/samples/next.js/zenstack/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema-lite"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; /** * User model */ diff --git a/samples/nuxt/zenstack/models.ts b/samples/nuxt/zenstack/models.ts index 3314c7d48..d878eac47 100644 --- a/samples/nuxt/zenstack/models.ts +++ b/samples/nuxt/zenstack/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as 
$Schema } from "./schema-lite"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; /** * User model */ diff --git a/samples/orm/zenstack/models.ts b/samples/orm/zenstack/models.ts index e2db380ea..b97ba05f6 100644 --- a/samples/orm/zenstack/models.ts +++ b/samples/orm/zenstack/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult, TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; /** * User model */ diff --git a/samples/sveltekit/src/zenstack/models.ts b/samples/sveltekit/src/zenstack/models.ts index 3314c7d48..d878eac47 100644 --- a/samples/sveltekit/src/zenstack/models.ts +++ b/samples/sveltekit/src/zenstack/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema-lite"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; /** * User model */ diff --git a/tests/e2e/apps/rally/zenstack/models.ts b/tests/e2e/apps/rally/zenstack/models.ts index e30e340b6..06e6ac4ff 100644 --- a/tests/e2e/apps/rally/zenstack/models.ts +++ b/tests/e2e/apps/rally/zenstack/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type Account = $ModelResult<$Schema, "Account">; export type User = $ModelResult<$Schema, "User">; export type VerificationToken = $ModelResult<$Schema, "VerificationToken">; diff --git a/tests/e2e/github-repos/cal.com/models.ts b/tests/e2e/github-repos/cal.com/models.ts index 26c9b777e..c8f35c09f 100644 --- 
a/tests/e2e/github-repos/cal.com/models.ts +++ b/tests/e2e/github-repos/cal.com/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type Host = $ModelResult<$Schema, "Host">; export type CalVideoSettings = $ModelResult<$Schema, "CalVideoSettings">; export type EventType = $ModelResult<$Schema, "EventType">; diff --git a/tests/e2e/github-repos/formbricks/models.ts b/tests/e2e/github-repos/formbricks/models.ts index 4c57997cf..740a0e041 100644 --- a/tests/e2e/github-repos/formbricks/models.ts +++ b/tests/e2e/github-repos/formbricks/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; /** * Represents a webhook endpoint for receiving survey-related events. * Webhooks can be configured to receive notifications about response creation, updates, and completion. 
diff --git a/tests/e2e/github-repos/trigger.dev/models.ts b/tests/e2e/github-repos/trigger.dev/models.ts index ba6e7cd16..70ee88aae 100644 --- a/tests/e2e/github-repos/trigger.dev/models.ts +++ b/tests/e2e/github-repos/trigger.dev/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type InvitationCode = $ModelResult<$Schema, "InvitationCode">; /** diff --git a/tests/e2e/orm/client-api/computed-fields.test.ts b/tests/e2e/orm/client-api/computed-fields.test.ts index c6470a720..1816854b8 100644 --- a/tests/e2e/orm/client-api/computed-fields.test.ts +++ b/tests/e2e/orm/client-api/computed-fields.test.ts @@ -3,6 +3,94 @@ import { sql } from 'kysely'; import { describe, expect, it } from 'vitest'; describe('Computed fields tests', () => { + it('throws error when computed field configuration is missing', async () => { + await expect( + createTestClient( + ` +model User { + id Int @id @default(autoincrement()) + name String + upperName String @computed +} +`, + { + // missing computedFields configuration + } as any, + ), + ).rejects.toThrow('Computed field "upperName" in model "User" does not have a configuration'); + }); + + it('throws error when computed field is missing from configuration', async () => { + await expect( + createTestClient( + ` +model User { + id Int @id @default(autoincrement()) + name String + upperName String @computed + lowerName String @computed +} +`, + { + computedFields: { + User: { + // only providing one of two computed fields + upperName: (eb: any) => eb.fn('upper', ['name']), + }, + }, + } as any, + ), + ).rejects.toThrow('Computed field "lowerName" in model "User" does not have a configuration'); + }); + + it('throws error when computed field configuration is not a function', async 
() => { + await expect( + createTestClient( + ` +model User { + id Int @id @default(autoincrement()) + name String + upperName String @computed +} +`, + { + computedFields: { + User: { + // providing a string instead of a function + upperName: 'not a function' as any, + }, + }, + } as any, + ), + ).rejects.toThrow( + 'Computed field "upperName" in model "User" has an invalid configuration: expected a function but received string', + ); + }); + + it('throws error when computed field configuration is a non-function object', async () => { + await expect( + createTestClient( + ` +model User { + id Int @id @default(autoincrement()) + name String + computed1 String @computed +} +`, + { + computedFields: { + User: { + // providing an object instead of a function + computed1: { key: 'value' } as any, + }, + }, + } as any, + ), + ).rejects.toThrow( + 'Computed field "computed1" in model "User" has an invalid configuration: expected a function but received object', + ); + }); + it('works with non-optional fields', async () => { const db = await createTestClient( ` @@ -102,6 +190,11 @@ model User { } `, { + computedFields: { + User: { + upperName: (eb: any) => eb.fn('upper', ['name']), + }, + }, extraSourceFiles: { main: ` import { ZenStackClient } from '@zenstackhq/orm'; @@ -169,6 +262,11 @@ model User { } `, { + computedFields: { + User: { + upperName: (eb: any) => eb.lit(null), + }, + }, extraSourceFiles: { main: ` import { ZenStackClient } from '@zenstackhq/orm'; diff --git a/tests/e2e/orm/plugin-infra/ext-query-args/models.ts b/tests/e2e/orm/plugin-infra/ext-query-args/models.ts index 7a605bdbc..996d95e2a 100644 --- a/tests/e2e/orm/plugin-infra/ext-query-args/models.ts +++ b/tests/e2e/orm/plugin-infra/ext-query-args/models.ts @@ -6,5 +6,5 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = 
$ModelResult<$Schema, "User">; diff --git a/tests/e2e/orm/schemas/auth-type/models.ts b/tests/e2e/orm/schemas/auth-type/models.ts index 75911fc76..2d78f9bbd 100644 --- a/tests/e2e/orm/schemas/auth-type/models.ts +++ b/tests/e2e/orm/schemas/auth-type/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult, TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; export type Foo = $ModelResult<$Schema, "Foo">; export type Permission = $TypeDefResult<$Schema, "Permission">; export type Auth = $TypeDefResult<$Schema, "Auth">; diff --git a/tests/e2e/orm/schemas/basic/models.ts b/tests/e2e/orm/schemas/basic/models.ts index 733e7df68..39bd52fdf 100644 --- a/tests/e2e/orm/schemas/basic/models.ts +++ b/tests/e2e/orm/schemas/basic/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult, TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; export type Comment = $ModelResult<$Schema, "Comment">; diff --git a/tests/e2e/orm/schemas/default-auth/models.ts b/tests/e2e/orm/schemas/default-auth/models.ts index 3f7f08e6b..624cfb8ec 100644 --- a/tests/e2e/orm/schemas/default-auth/models.ts +++ b/tests/e2e/orm/schemas/default-auth/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Profile = $ModelResult<$Schema, "Profile">; export 
type Address = $ModelResult<$Schema, "Address">; diff --git a/tests/e2e/orm/schemas/delegate/models.ts b/tests/e2e/orm/schemas/delegate/models.ts index 0a4350d20..e1fe994bd 100644 --- a/tests/e2e/orm/schemas/delegate/models.ts +++ b/tests/e2e/orm/schemas/delegate/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Comment = $ModelResult<$Schema, "Comment">; export type Asset = $ModelResult<$Schema, "Asset">; diff --git a/tests/e2e/orm/schemas/json/models.ts b/tests/e2e/orm/schemas/json/models.ts index 24ea716d8..bdbfabad1 100644 --- a/tests/e2e/orm/schemas/json/models.ts +++ b/tests/e2e/orm/schemas/json/models.ts @@ -6,5 +6,5 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type Foo = $ModelResult<$Schema, "Foo">; diff --git a/tests/e2e/orm/schemas/name-mapping/models.ts b/tests/e2e/orm/schemas/name-mapping/models.ts index 944ad9cb5..9c5a74865 100644 --- a/tests/e2e/orm/schemas/name-mapping/models.ts +++ b/tests/e2e/orm/schemas/name-mapping/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; export const Role = $schema.enums.Role.values; diff --git a/tests/e2e/orm/schemas/omit/models.ts b/tests/e2e/orm/schemas/omit/models.ts index 6636b4d54..5044cb91a 100644 --- a/tests/e2e/orm/schemas/omit/models.ts +++ b/tests/e2e/orm/schemas/omit/models.ts @@ 
-6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; export type Base = $ModelResult<$Schema, "Base">; diff --git a/tests/e2e/orm/schemas/petstore/models.ts b/tests/e2e/orm/schemas/petstore/models.ts index dfa5b23e6..b55674cb0 100644 --- a/tests/e2e/orm/schemas/petstore/models.ts +++ b/tests/e2e/orm/schemas/petstore/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Pet = $ModelResult<$Schema, "Pet">; export type Order = $ModelResult<$Schema, "Order">; diff --git a/tests/e2e/orm/schemas/procedures/models.ts b/tests/e2e/orm/schemas/procedures/models.ts index 9920c1011..f5d8f42dc 100644 --- a/tests/e2e/orm/schemas/procedures/models.ts +++ b/tests/e2e/orm/schemas/procedures/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult, TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Overview = $TypeDefResult<$Schema, "Overview">; export const Role = $schema.enums.Role.values; diff --git a/tests/e2e/orm/schemas/todo/models.ts b/tests/e2e/orm/schemas/todo/models.ts index 635b68dee..7d3c2008e 100644 --- a/tests/e2e/orm/schemas/todo/models.ts +++ b/tests/e2e/orm/schemas/todo/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; 
-import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type Space = $ModelResult<$Schema, "Space">; export type SpaceUser = $ModelResult<$Schema, "SpaceUser">; export type User = $ModelResult<$Schema, "User">; diff --git a/tests/e2e/orm/schemas/typed-json/models.ts b/tests/e2e/orm/schemas/typed-json/models.ts index 2b2474faa..e46ea2c7b 100644 --- a/tests/e2e/orm/schemas/typed-json/models.ts +++ b/tests/e2e/orm/schemas/typed-json/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult, TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Profile = $TypeDefResult<$Schema, "Profile">; export type Address = $TypeDefResult<$Schema, "Address">; diff --git a/tests/e2e/orm/schemas/typing/models.ts b/tests/e2e/orm/schemas/typing/models.ts index 15eae9a94..20e775f4e 100644 --- a/tests/e2e/orm/schemas/typing/models.ts +++ b/tests/e2e/orm/schemas/typing/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as $schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult, TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; export type Profile = $ModelResult<$Schema, "Profile">; diff --git a/tests/regression/test/issue-204/models.ts b/tests/regression/test/issue-204/models.ts index 633c334b2..4a284d844 100644 --- a/tests/regression/test/issue-204/models.ts +++ b/tests/regression/test/issue-204/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { schema as 
$schema, type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult, TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; export type Foo = $ModelResult<$Schema, "Foo">; export type Configuration = $TypeDefResult<$Schema, "Configuration">; export const ShirtColor = $schema.enums.ShirtColor.values; diff --git a/tests/regression/test/issue-422/models.ts b/tests/regression/test/issue-422/models.ts index 787d1c77c..4255e3959 100644 --- a/tests/regression/test/issue-422/models.ts +++ b/tests/regression/test/issue-422/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type Session = $ModelResult<$Schema, "Session">; export type User = $ModelResult<$Schema, "User">; export type Profile = $ModelResult<$Schema, "Profile">; diff --git a/tests/regression/test/issue-503/models.ts b/tests/regression/test/issue-503/models.ts index e71f3a616..213a31f3a 100644 --- a/tests/regression/test/issue-503/models.ts +++ b/tests/regression/test/issue-503/models.ts @@ -6,7 +6,7 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type InternalChat = $ModelResult<$Schema, "InternalChat">; export type Message = $ModelResult<$Schema, "Message">; export type Media = $ModelResult<$Schema, "Media">; diff --git a/tests/runtimes/bun/schemas/models.ts b/tests/runtimes/bun/schemas/models.ts index 72654e587..03524da52 100644 --- a/tests/runtimes/bun/schemas/models.ts +++ b/tests/runtimes/bun/schemas/models.ts @@ -6,6 +6,6 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; 
-import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; diff --git a/tests/runtimes/edge-runtime/schemas/models.ts b/tests/runtimes/edge-runtime/schemas/models.ts index 72654e587..03524da52 100644 --- a/tests/runtimes/edge-runtime/schemas/models.ts +++ b/tests/runtimes/edge-runtime/schemas/models.ts @@ -6,6 +6,6 @@ /* eslint-disable */ import { type SchemaType as $Schema } from "./schema"; -import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; +import type { ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">;