diff --git a/sqlspec/config.py b/sqlspec/config.py index 0f8abfe1..5c23b9c7 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1,3 +1,5 @@ +import asyncio +import threading from abc import ABC, abstractmethod from collections.abc import Callable from inspect import Signature, signature @@ -1510,7 +1512,7 @@ async def fix_migrations(self, dry_run: bool = False, update_database: bool = Tr class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Base class for sync database configurations with connection pooling.""" - __slots__ = ("connection_config",) + __slots__ = ("_pool_lock", "connection_config") is_async: "ClassVar[bool]" = False supports_connection_pooling: "ClassVar[bool]" = True migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker @@ -1549,6 +1551,7 @@ def __init__( self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) self._promote_driver_feature_hooks() self._configure_observability_extensions() + self._pool_lock = threading.Lock() def create_pool(self) -> PoolT: """Create and return the connection pool. 
@@ -1558,9 +1561,14 @@ def create_pool(self) -> PoolT: """ if self.connection_instance is not None: return self.connection_instance - self.connection_instance = self._create_pool() - self.get_observability_runtime().emit_pool_create(self.connection_instance) - return self.connection_instance + + with self._pool_lock: + if self.connection_instance is not None: + return self.connection_instance + + self.connection_instance = self._create_pool() + self.get_observability_runtime().emit_pool_create(self.connection_instance) + return self.connection_instance def close_pool(self) -> None: """Close the connection pool.""" @@ -1572,9 +1580,7 @@ def close_pool(self) -> None: def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT: """Provide pool instance.""" - if self.connection_instance is None: - self.connection_instance = self.create_pool() - return self.connection_instance + return self.create_pool() def create_connection(self) -> ConnectionT: """Create a database connection.""" @@ -1709,7 +1715,7 @@ def fix_migrations(self, dry_run: bool = False, update_database: bool = True, ye class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Base class for async database configurations with connection pooling.""" - __slots__ = ("connection_config",) + __slots__ = ("_pool_lock", "connection_config") is_async: "ClassVar[bool]" = True supports_connection_pooling: "ClassVar[bool]" = True migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker @@ -1750,6 +1756,7 @@ def __init__( self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) self._promote_driver_feature_hooks() self._configure_observability_extensions() + self._pool_lock = asyncio.Lock() async def create_pool(self) -> PoolT: """Create and return the connection pool. 
@@ -1759,9 +1766,14 @@ async def create_pool(self) -> PoolT: """ if self.connection_instance is not None: return self.connection_instance - self.connection_instance = await self._create_pool() - self.get_observability_runtime().emit_pool_create(self.connection_instance) - return self.connection_instance + + async with self._pool_lock: + if self.connection_instance is not None: + return self.connection_instance + + self.connection_instance = await self._create_pool() + self.get_observability_runtime().emit_pool_create(self.connection_instance) + return self.connection_instance async def close_pool(self) -> None: """Close the connection pool.""" @@ -1773,9 +1785,7 @@ async def close_pool(self) -> None: async def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT: """Provide pool instance.""" - if self.connection_instance is None: - self.connection_instance = await self.create_pool() - return self.connection_instance + return await self.create_pool() async def create_connection(self) -> ConnectionT: """Create a database connection.""" diff --git a/tests/integration/test_pool_concurrency.py b/tests/integration/test_pool_concurrency.py new file mode 100644 index 00000000..48d7016f --- /dev/null +++ b/tests/integration/test_pool_concurrency.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import asyncio +import threading +from typing import TYPE_CHECKING + +import pytest + +from sqlspec.adapters.asyncpg import AsyncpgConfig +from sqlspec.adapters.duckdb import DuckDBConfig + +if TYPE_CHECKING: + from pytest_databases.docker.postgres import PostgresService + + from sqlspec.adapters.asyncpg import AsyncpgPool + from sqlspec.adapters.duckdb import DuckDBConnectionPool + + +@pytest.mark.asyncio +async def test_asyncpg_pool_concurrency(postgres_service: PostgresService) -> None: + """Verify that multiple concurrent calls to provide_pool result in a single pool.""" + config_params = { + "host": postgres_service.host, + "port": postgres_service.port, + "user": 
postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + } + # Initialize with connection_instance=None explicitly just to be sure + config = AsyncpgConfig(connection_config=config_params, connection_instance=None) + + async def get_pool() -> AsyncpgPool: + # No artificial delay is needed: asyncio.gather starts every coroutine before + # any finishes, so the connection_instance checks (check-then-act) overlap + return await config.provide_pool() + + # Launch many tasks simultaneously + tasks = [get_pool() for _ in range(50)] + pools = await asyncio.gather(*tasks) + + # All pools should be the exact same object + first_pool = pools[0] + unique_pools = {id(p) for p in pools} + + await config.close_pool() + + assert len(unique_pools) == 1, f"Race condition detected! {len(unique_pools)} unique pools created." + assert all(p is first_pool for p in pools) + + +def test_duckdb_pool_concurrency() -> None: + """Verify that multiple concurrent calls to provide_pool result in a single pool (Sync).""" + # Use an in-memory database; the test only checks pool identity, not shared data + config = DuckDBConfig(connection_config={"database": ":memory:"}) + + # We need to capture results from threads + results: list[DuckDBConnectionPool | None] = [None] * 50 + exceptions: list[Exception] = [] + + def get_pool(index: int) -> None: + try: + pool = config.provide_pool() + results[index] = pool + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=get_pool, args=(i,)) for i in range(50)] + + for t in threads: + t.start() + for t in threads: + t.join() + + if exceptions: + pytest.fail(f"Exceptions in threads: {exceptions}") + + unique_pools = {id(p) for p in results if p is not None} + config.close_pool() + + assert len(unique_pools) == 1, f"Race condition detected! {len(unique_pools)} unique DuckDB pools created."