Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
bin

# macOS
.DS_Store

# Devenv
.envrc
.direnv
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ version: "2"
plugins:
- name: py
wasm:
url: https://downloads.sqlc.dev/plugin/sqlc-gen-python_1.2.0.wasm
sha256: a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e
url: https://downloads.sqlc.dev/plugin/sqlc-gen-python_1.3.0.wasm
sha256: fbedae96b5ecae2380a70fb5b925fd4bff58a6cfb1f3140375d098fbab7b3a3c
sql:
- schema: "schema.sql"
queries: "query.sql"
Expand Down
13 changes: 13 additions & 0 deletions internal/endtoend/testdata/emit_numpy_array/python/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.30.0
import dataclasses
import numpy
from numpy.typing import NDArray
from typing import Optional


@dataclasses.dataclass()
class Item:
id: int
embedding: Optional[NDArray[numpy.float32]]
68 changes: 68 additions & 0 deletions internal/endtoend/testdata/emit_numpy_array/python/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.30.0
# source: query.sql
import numpy
from numpy.typing import NDArray
from typing import Optional

import sqlalchemy
import sqlalchemy.ext.asyncio

from python import models


CREATE_ITEM = """-- name: create_item \\:one
INSERT INTO items (embedding) VALUES (:p1) RETURNING id, embedding
"""


GET_ITEM = """-- name: get_item \\:one
SELECT id, embedding FROM items WHERE id = :p1
"""


class Querier:
def __init__(self, conn: sqlalchemy.engine.Connection):
self._conn = conn

def create_item(self, *, embedding: Optional[NDArray[numpy.float32]]) -> Optional[models.Item]:
row = self._conn.execute(sqlalchemy.text(CREATE_ITEM), {"p1": embedding}).first()
if row is None:
return None
return models.Item(
id=row[0],
embedding=row[1],
)

def get_item(self, *, id: int) -> Optional[models.Item]:
row = self._conn.execute(sqlalchemy.text(GET_ITEM), {"p1": id}).first()
if row is None:
return None
return models.Item(
id=row[0],
embedding=row[1],
)


class AsyncQuerier:
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
self._conn = conn

async def create_item(self, *, embedding: Optional[NDArray[numpy.float32]]) -> Optional[models.Item]:
row = (await self._conn.execute(sqlalchemy.text(CREATE_ITEM), {"p1": embedding})).first()
if row is None:
return None
return models.Item(
id=row[0],
embedding=row[1],
)

async def get_item(self, *, id: int) -> Optional[models.Item]:
row = (await self._conn.execute(sqlalchemy.text(GET_ITEM), {"p1": id})).first()
if row is None:
return None
return models.Item(
id=row[0],
embedding=row[1],
)
6 changes: 6 additions & 0 deletions internal/endtoend/testdata/emit_numpy_array/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- name: GetItem :one
SELECT * FROM items WHERE id = $1;


-- name: CreateItem :one
INSERT INTO items (embedding) VALUES ($1) RETURNING *;
4 changes: 4 additions & 0 deletions internal/endtoend/testdata/emit_numpy_array/schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE items (
id SERIAL PRIMARY KEY,
embedding vector(3)
);
17 changes: 17 additions & 0 deletions internal/endtoend/testdata/emit_numpy_array/sqlc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
version: '2'
plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "839af1f07c31644548192fc095569e62f1511d72c1c30c1a958ddc9c9429edbc"
sql:
- schema: schema.sql
queries: query.sql
engine: postgresql
codegen:
- plugin: py
out: python
options:
package: python
emit_sync_querier: true
emit_async_querier: true
4 changes: 4 additions & 0 deletions internal/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,5 +272,9 @@ func stdImports(uses func(name string) bool) map[string]importSpec {
if uses("Any") {
std["typing.Any"] = importSpec{Module: "typing", Name: "Any"}
}
if uses("NDArray[numpy.float32]") {
std["numpy"] = importSpec{Module: "numpy"}
std["numpy.typing.NDArray"] = importSpec{Module: "numpy.typing", Name: "NDArray"}
}
return std
}
2 changes: 2 additions & 0 deletions internal/postgresql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ func postgresType(req *plugin.GenerateRequest, col *plugin.Column) string {
return "str"
case "ltree", "lquery", "ltxtquery":
return "str"
case "vector":
return "NDArray[numpy.float32]"
default:
for _, schema := range req.Catalog.Schemas {
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
Expand Down