Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions sqlspec/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from inspect import Signature, signature
Expand Down Expand Up @@ -1510,7 +1512,7 @@ async def fix_migrations(self, dry_run: bool = False, update_database: bool = Tr
class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
"""Base class for sync database configurations with connection pooling."""

__slots__ = ("connection_config",)
__slots__ = ("_pool_lock", "connection_config")
is_async: "ClassVar[bool]" = False
supports_connection_pooling: "ClassVar[bool]" = True
migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker
Expand Down Expand Up @@ -1549,6 +1551,7 @@ def __init__(
self.driver_features.setdefault("storage_capabilities", self.storage_capabilities())
self._promote_driver_feature_hooks()
self._configure_observability_extensions()
self._pool_lock = threading.Lock()

def create_pool(self) -> PoolT:
"""Create and return the connection pool.
Expand All @@ -1558,9 +1561,14 @@ def create_pool(self) -> PoolT:
"""
if self.connection_instance is not None:
return self.connection_instance
self.connection_instance = self._create_pool()
self.get_observability_runtime().emit_pool_create(self.connection_instance)
return self.connection_instance

with self._pool_lock:
if self.connection_instance is not None:
return self.connection_instance

self.connection_instance = self._create_pool()
self.get_observability_runtime().emit_pool_create(self.connection_instance)
return self.connection_instance

def close_pool(self) -> None:
"""Close the connection pool."""
Expand All @@ -1572,9 +1580,7 @@ def close_pool(self) -> None:

def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
"""Provide pool instance."""
if self.connection_instance is None:
self.connection_instance = self.create_pool()
return self.connection_instance
return self.create_pool()

def create_connection(self) -> ConnectionT:
"""Create a database connection."""
Expand Down Expand Up @@ -1709,7 +1715,7 @@ def fix_migrations(self, dry_run: bool = False, update_database: bool = True, ye
class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
"""Base class for async database configurations with connection pooling."""

__slots__ = ("connection_config",)
__slots__ = ("_pool_lock", "connection_config")
is_async: "ClassVar[bool]" = True
supports_connection_pooling: "ClassVar[bool]" = True
migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker
Expand Down Expand Up @@ -1750,6 +1756,7 @@ def __init__(
self.driver_features.setdefault("storage_capabilities", self.storage_capabilities())
self._promote_driver_feature_hooks()
self._configure_observability_extensions()
self._pool_lock = asyncio.Lock()

async def create_pool(self) -> PoolT:
"""Create and return the connection pool.
Expand All @@ -1759,9 +1766,14 @@ async def create_pool(self) -> PoolT:
"""
if self.connection_instance is not None:
return self.connection_instance
self.connection_instance = await self._create_pool()
self.get_observability_runtime().emit_pool_create(self.connection_instance)
return self.connection_instance

async with self._pool_lock:
if self.connection_instance is not None:
return self.connection_instance

self.connection_instance = await self._create_pool()
self.get_observability_runtime().emit_pool_create(self.connection_instance)
return self.connection_instance

async def close_pool(self) -> None:
"""Close the connection pool."""
Expand All @@ -1773,9 +1785,7 @@ async def close_pool(self) -> None:

async def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
"""Provide pool instance."""
if self.connection_instance is None:
self.connection_instance = await self.create_pool()
return self.connection_instance
return await self.create_pool()

async def create_connection(self) -> ConnectionT:
"""Create a database connection."""
Expand Down
80 changes: 80 additions & 0 deletions tests/integration/test_pool_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import asyncio
import threading
from typing import TYPE_CHECKING

import pytest

from sqlspec.adapters.asyncpg import AsyncpgConfig
from sqlspec.adapters.duckdb import DuckDBConfig

if TYPE_CHECKING:
from pytest_databases.docker.postgres import PostgresService

from sqlspec.adapters.asyncpg import AsyncpgPool
from sqlspec.adapters.duckdb import DuckDBConnectionPool


@pytest.mark.asyncio
async def test_asyncpg_pool_concurrency(postgres_service: PostgresService) -> None:
    """Verify that multiple concurrent calls to provide_pool result in a single pool.

    Launches 50 overlapping ``provide_pool`` awaits against one config; the
    lock guarding ``create_pool`` must ensure exactly one pool object is
    ever created and shared by every caller.
    """
    config_params = {
        "host": postgres_service.host,
        "port": postgres_service.port,
        "user": postgres_service.user,
        "password": postgres_service.password,
        "database": postgres_service.database,
    }
    # Start from a known-clean state: no pre-existing pool instance.
    config = AsyncpgConfig(connection_config=config_params, connection_instance=None)

    async def get_pool() -> AsyncpgPool:
        # Every task races through the same check-then-act path in create_pool;
        # gather() below schedules them close enough together to overlap.
        return await config.provide_pool()

    # Launch many tasks simultaneously.
    tasks = [get_pool() for _ in range(50)]
    try:
        pools = await asyncio.gather(*tasks)
    finally:
        # Release the pool even if gather (or a later assertion) fails,
        # so the test never leaks connections on failure.
        await config.close_pool()

    # All callers must have received the exact same pool object.
    first_pool = pools[0]
    unique_pools = {id(p) for p in pools}

    assert len(unique_pools) == 1, f"Race condition detected! {len(unique_pools)} unique pools created."
    assert all(p is first_pool for p in pools)


def test_duckdb_pool_concurrency() -> None:
    """Verify that multiple concurrent calls to provide_pool result in a single pool (Sync).

    Spawns 50 threads that all call ``provide_pool`` on one config; the
    lock guarding the sync ``create_pool`` must ensure exactly one pool
    object is ever created and shared by every caller.
    """
    # In-memory database keeps the test self-contained (no files, no services).
    # NOTE(review): plain ":memory:" is per-connection in DuckDB, not shared —
    # that is fine here because the test only checks pool-object identity.
    config = DuckDBConfig(connection_config={"database": ":memory:"})

    # Capture per-thread results and any raised exceptions for the main thread.
    results: list[DuckDBConnectionPool | None] = [None] * 50
    exceptions: list[Exception] = []

    def get_pool(index: int) -> None:
        try:
            results[index] = config.provide_pool()
        except Exception as e:
            # Surface thread failures to the main thread instead of dying silently.
            exceptions.append(e)

    threads = [threading.Thread(target=get_pool, args=(i,)) for i in range(50)]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    try:
        if exceptions:
            pytest.fail(f"Exceptions in threads: {exceptions}")

        # All threads must have received the exact same pool object.
        unique_pools = {id(p) for p in results if p is not None}
        assert len(unique_pools) == 1, f"Race condition detected! {len(unique_pools)} unique DuckDB pools created."
    finally:
        # Release the pool even when pytest.fail or the assertion fires,
        # so a failing test never leaks the pool.
        config.close_pool()