diff --git a/.github/workflows/ci-test.yml b/.github/workflows/ci-test.yml index 2ee494c5..a11cf75f 100644 --- a/.github/workflows/ci-test.yml +++ b/.github/workflows/ci-test.yml @@ -23,9 +23,9 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13"] os: ["ubuntu-latest", "macOS-latest", "windows-latest"] - backend: ["local", "mongodb", "postgres", "redis"] + backend: ["local", "mongodb", "postgres", "redis", "s3"] exclude: # ToDo: take if back when the connection become stable # or resolve using `InMemoryMongoClient` @@ -65,7 +65,7 @@ jobs: - name: Unit tests (local) if: matrix.backend == 'local' - run: pytest -m "not mongo and not sql and not redis" --cov=cachier --cov-report=term --cov-report=xml:cov.xml + run: pytest -m "not mongo and not sql and not redis and not s3" --cov=cachier --cov-report=term --cov-report=xml:cov.xml - name: Setup docker (missing on MacOS) if: runner.os == 'macOS' && matrix.backend == 'mongodb' @@ -135,6 +135,10 @@ jobs: if: matrix.backend == 'redis' run: pytest -m redis --cov=cachier --cov-report=term --cov-report=xml:cov.xml + - name: Unit tests (S3) + if: matrix.backend == 's3' + run: pytest -m s3 --cov=cachier --cov-report=term --cov-report=xml:cov.xml + - name: Upload coverage to Codecov (non PRs) continue-on-error: true uses: codecov/codecov-action@v5 diff --git a/README.rst b/README.rst index 55e38286..fca79622 100644 --- a/README.rst +++ b/README.rst @@ -59,6 +59,7 @@ Current features * Cross-machine caching using MongoDB. * SQL-based caching using SQLAlchemy-supported databases. * Redis-based caching for high-performance scenarios. + * S3-based caching for cross-machine object storage backends. * Thread-safety. * **Per-call max age:** Specify a maximum age for cached values per call. @@ -71,7 +72,6 @@ Cachier is **NOT**: Future features --------------- -* S3 core. * Multi-core caching. * `Cache replacement policies `_ @@ -580,6 +580,12 @@ Cachier supports Redis-based caching for high-performance scenarios. Redis provi - ``processing``: Boolean, is value being calculated - ``completed``: Boolean, is value calculation completed +**S3 Sync/Async Support:** + +- Sync functions use direct boto3 calls. +- Async functions are supported via thread-offloaded sync boto3 calls + (delegated mode), not a native async client. + **Limitations & Notes:** - Requires SQLAlchemy (install with ``pip install SQLAlchemy``) @@ -631,6 +637,11 @@ async drivers and require the client or engine type to match the decorated funct - ``redis_client`` must be a sync client or sync callable for sync functions and an async callable returning a ``redis.asyncio.Redis`` client for async functions. Passing a sync callable to an async function raises ``TypeError``. + * - **S3** + - Yes + - Yes (delegated) + - Async support is delegated via thread-offloaded sync boto3 calls + (``asyncio.to_thread``). No async S3 client is required. Contributing @@ -655,13 +666,14 @@ Install in development mode with test dependencies for local cores (memory and p cd cachier pip install -e . -r tests/requirements.txt -Each additional core (MongoDB, Redis, SQL) requires additional dependencies. To install all dependencies for all cores, run: +Each additional core (MongoDB, Redis, SQL, S3) requires additional dependencies. To install all dependencies for all cores, run: .. 
code-block:: bash pip install -r tests/requirements_mongodb.txt pip install -r tests/requirements_redis.txt pip install -r tests/requirements_postgres.txt + pip install -r tests/requirements_s3.txt Running the tests ----------------- @@ -724,7 +736,7 @@ This script automatically handles Docker container lifecycle, environment variab .. code-block:: bash make test-mongo-local # Run MongoDB tests with Docker - make test-all-local # Run all backends with Docker + make test-all-local # Run all backends locally (Docker used for mongo/redis/sql) make test-mongo-inmemory # Run with in-memory MongoDB (default) **Option 3: Manual setup** @@ -750,18 +762,21 @@ Contributors are encouraged to test against a real MongoDB instance before submi Testing all backends locally ----------------------------- -To test all cachier backends (MongoDB, Redis, SQL, Memory, Pickle) locally with Docker: +To test all cachier backends (MongoDB, Redis, SQL, S3, Memory, Pickle) locally: .. code-block:: bash # Test all backends at once ./scripts/test-local.sh all - # Test only external backends (MongoDB, Redis, SQL) + # Test only external backends that require Docker (MongoDB, Redis, SQL) ./scripts/test-local.sh external + # Test S3 backend only (uses moto, no Docker needed) + ./scripts/test-local.sh s3 + # Test specific combinations - ./scripts/test-local.sh mongo redis + ./scripts/test-local.sh mongo redis s3 # Keep containers running for debugging ./scripts/test-local.sh all -k @@ -772,7 +787,7 @@ To test all cachier backends (MongoDB, Redis, SQL, Memory, Pickle) locally with # Test multiple files across all backends ./scripts/test-local.sh all -f tests/test_main.py -f tests/test_redis_core_coverage.py -The unified test script automatically manages Docker containers, installs required dependencies, and runs the appropriate test suites. The ``-f`` / ``--files`` option allows you to run specific test files instead of the entire test suite. See ``scripts/README-local-testing.md`` for detailed documentation. +The unified test script automatically manages Docker containers for MongoDB/Redis/SQL, installs required dependencies (including ``tests/requirements_s3.txt`` for S3), and runs the appropriate test suites. The ``-f`` / ``--files`` option allows you to run specific test files instead of the entire test suite. See ``scripts/README-local-testing.md`` for detailed documentation. Running pre-commit hooks locally diff --git a/examples/s3_example.py b/examples/s3_example.py new file mode 100644 index 00000000..b6bbb5f7 --- /dev/null +++ b/examples/s3_example.py @@ -0,0 +1,208 @@ +"""Cachier S3 backend example. + +Demonstrates persistent function caching backed by AWS S3 (or any S3-compatible +service). Requires boto3 to be installed:: + + pip install cachier[s3] + +A real S3 bucket (or a local S3-compatible service such as MinIO / localstack) +is needed to run this example. Adjust the configuration variables below to +match your environment. 
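+
+At its simplest, the decorator usage exercised by the demos below looks like
+this (``add`` is a throwaway illustrative function and the bucket name is the
+same placeholder configured further down)::
+
+    from cachier import cachier
+
+    @cachier(backend="s3", s3_bucket="my-cachier-bucket")
+    def add(a: int, b: int) -> int:
+        return a + b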
+ +""" + +import time +from datetime import timedelta + +try: + import boto3 + + from cachier import cachier +except ImportError as exc: + print(f"Missing required package: {exc}") + print("Install with: pip install cachier[s3]") + raise SystemExit(1) from exc + +# --------------------------------------------------------------------------- +# Configuration - adjust these to your environment +# --------------------------------------------------------------------------- +BUCKET_NAME = "my-cachier-bucket" +REGION = "us-east-1" + +# Optional: point to a local S3-compatible service +# ENDPOINT_URL = "http://localhost:9000" # MinIO default +ENDPOINT_URL = None + + +# --------------------------------------------------------------------------- +# Helper: verify S3 connectivity +# --------------------------------------------------------------------------- + + +def _check_bucket(client, bucket: str) -> bool: + """Return True if the bucket is accessible.""" + try: + client.head_bucket(Bucket=bucket) + return True + except Exception as exc: + print(f"Cannot access bucket '{bucket}': {exc}") + return False + + +# --------------------------------------------------------------------------- +# Demos +# --------------------------------------------------------------------------- + + +def demo_basic_caching(): + """Show basic S3 caching: the first call computes, the second reads cache.""" + print("\n=== Basic S3 caching ===") + + @cachier( + backend="s3", + s3_bucket=BUCKET_NAME, + s3_region=REGION, + s3_endpoint_url=ENDPOINT_URL, + ) + def expensive(n: int) -> int: + """Simulate an expensive computation.""" + print(f" computing expensive({n})...") + time.sleep(1) + return n * n + + expensive.clear_cache() + + start = time.time() + r1 = expensive(5) + t1 = time.time() - start + print(f"First call: {r1} ({t1:.2f}s)") + + start = time.time() + r2 = expensive(5) + t2 = time.time() - start + print(f"Second call: {r2} ({t2:.2f}s) - from cache") + + assert r1 == r2 + assert t2 < t1 + print("Basic caching works correctly.") + + +def demo_stale_after(): + """Show stale_after: results expire and are recomputed after the timeout.""" + print("\n=== Stale-after demo ===") + + @cachier( + backend="s3", + s3_bucket=BUCKET_NAME, + s3_region=REGION, + s3_endpoint_url=ENDPOINT_URL, + stale_after=timedelta(seconds=3), + ) + def timed(n: int) -> float: + print(f" computing timed({n})...") + return time.time() + + timed.clear_cache() + r1 = timed(1) + r2 = timed(1) + assert r1 == r2, "Second call should hit cache" + + print("Sleeping 4 seconds so the entry becomes stale...") + time.sleep(4) + + r3 = timed(1) + assert r3 > r1, "Should have recomputed after stale period" + print("Stale-after works correctly.") + + +def demo_client_factory(): + """Show using a callable factory instead of a pre-built client.""" + print("\n=== Client factory demo ===") + + def make_client(): + """Lazily create a boto3 S3 client.""" + kwargs = {"region_name": REGION} + if ENDPOINT_URL: + kwargs["endpoint_url"] = ENDPOINT_URL + return boto3.client("s3", **kwargs) + + @cachier( + backend="s3", + s3_bucket=BUCKET_NAME, + s3_client_factory=make_client, + ) + def compute(n: int) -> int: + return n + 100 + + compute.clear_cache() + assert compute(7) == compute(7) + print("Client factory works correctly.") + + +def demo_cache_management(): + """Show clear_cache and overwrite_cache.""" + print("\n=== Cache management demo ===") + call_count = [0] + + @cachier( + backend="s3", + s3_bucket=BUCKET_NAME, + s3_region=REGION, + s3_endpoint_url=ENDPOINT_URL, + ) + 
def managed(n: int) -> int: + call_count[0] += 1 + return n * 3 + + managed.clear_cache() + managed(10) + managed(10) + assert call_count[0] == 1, "Should have been called once (cached on second call)" + + managed.clear_cache() + managed(10) + assert call_count[0] == 2, "Should have recomputed after cache clear" + + managed(10, cachier__overwrite_cache=True) + assert call_count[0] == 3, "Should have recomputed due to overwrite_cache" + print("Cache management works correctly.") + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + """Run all S3 backend demos.""" + print("Cachier S3 Backend Demo") + print("=" * 50) + + client = boto3.client( + "s3", + region_name=REGION, + **({"endpoint_url": ENDPOINT_URL} if ENDPOINT_URL else {}), + ) + + if not _check_bucket(client, BUCKET_NAME): + print(f"\nCreate the bucket first: aws s3 mb s3://{BUCKET_NAME} --region {REGION}") + raise SystemExit(1) + + try: + demo_basic_caching() + demo_stale_after() + demo_client_factory() + demo_cache_management() + + print("\n" + "=" * 50) + print("All S3 demos completed successfully.") + print("\nKey benefits of the S3 backend:") + print("- Persistent cache survives process restarts") + print("- Shared across machines without a running service") + print("- Works with any S3-compatible object storage") + finally: + client.close() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index b3565e8e..cb81b635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,25 @@ dependencies = [ "pympler>=1", "watchdog>=2.3.1", ] + +optional-dependencies.all = [ + "boto3>=1.26", + "pymongo>=4", + "redis>=4", + "sqlalchemy>=2", +] +optional-dependencies.mongo = [ + "pymongo>=4", +] +optional-dependencies.redis = [ + "redis>=4", +] +optional-dependencies.s3 = [ + "boto3>=1.26", +] +optional-dependencies.sql = [ + "sqlalchemy>=2", +] urls.Source = "https://github.com/python-cachier/cachier" # --- setuptools --- @@ -177,6 +196,7 @@ markers = [ "pickle: test the pickle core", "redis: test the Redis core", "sql: test the SQL core", + "s3: test the S3 core", "maxage: test the max_age functionality", "asyncio: marks tests as async", ] diff --git a/scripts/test-local.sh b/scripts/test-local.sh index 0d07661b..e7efc571 100755 --- a/scripts/test-local.sh +++ b/scripts/test-local.sh @@ -45,9 +45,10 @@ CORES: mongo MongoDB backend tests redis Redis backend tests sql SQL (PostgreSQL) backend tests + s3 S3 backend tests (no Docker needed) memory Memory backend tests (no Docker needed) pickle Pickle backend tests (no Docker needed) - all All backends (equivalent to: mongo redis sql memory pickle) + all All backends (equivalent to: mongo redis sql s3 memory pickle) external All external backends (mongo redis sql) local All local backends (memory pickle) @@ -133,7 +134,7 @@ expand_cores() { for core in $1; do case $core in all) - cores="$cores mongo redis sql memory pickle" + cores="$cores mongo redis sql s3 memory pickle" ;; external) cores="$cores mongo redis sql" @@ -158,6 +159,7 @@ get_markers_for_core() { mongo) echo "mongo" ;; redis) echo "redis" ;; sql) echo "sql" ;; + s3) echo "s3" ;; memory) echo "memory" ;; pickle) echo "pickle or maxage" ;; *) echo "$1" ;; # Default to core name @@ -166,11 +168,11 @@ get_markers_for_core() { # Validate cores validate_cores() { - local valid_cores="mongo redis sql memory pickle" + local valid_cores="mongo redis sql s3 
memory pickle" for core in $1; do if ! echo "$valid_cores" | grep -qw "$core"; then print_message $RED "Error: Invalid core '$core'" - print_message $YELLOW "Valid cores: mongo, redis, sql, memory, pickle" + print_message $YELLOW "Valid cores: mongo, redis, sql, s3, memory, pickle" exit 1 fi done @@ -265,6 +267,17 @@ check_dependencies() { fi fi + # Check S3 dependencies if testing S3 + if echo "$SELECTED_CORES" | grep -qw "s3"; then + if ! python -c "import boto3; import moto" 2>/dev/null; then + print_message $YELLOW "Installing S3 test requirements..." + pip install -r tests/requirements_s3.txt || { + print_message $RED "Failed to install S3 requirements" + exit 1 + } + fi + fi + print_message $GREEN "All required dependencies are installed!" } @@ -498,7 +511,7 @@ main() { # Add markers if needed (only if no specific test files were given) if [ -z "$TEST_FILES" ]; then # Check if we selected all cores - if so, run all tests without marker filtering - all_cores="memory mongo pickle redis sql" + all_cores="memory mongo pickle redis s3 sql" selected_sorted=$(echo "$SELECTED_CORES" | tr ' ' '\n' | sort | tr '\n' ' ' | xargs) all_sorted=$(echo "$all_cores" | tr ' ' '\n' | sort | tr '\n' ' ' | xargs) @@ -507,7 +520,7 @@ main() { fi else # When test files are specified, still apply markers if not running all cores - all_cores="memory mongo pickle redis sql" + all_cores="memory mongo pickle redis s3 sql" selected_sorted=$(echo "$SELECTED_CORES" | tr ' ' '\n' | sort | tr '\n' ' ' | xargs) all_sorted=$(echo "$all_cores" | tr ' ' '\n' | sort | tr '\n' ' ' | xargs) diff --git a/src/cachier/_types.py b/src/cachier/_types.py index a4a814f0..01486069 100644 --- a/src/cachier/_types.py +++ b/src/cachier/_types.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Awaitable, Callable, Literal, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Literal, Union if TYPE_CHECKING: import pymongo.collection @@ -8,4 +8,5 @@ HashFunc = Callable[..., str] Mongetter = Callable[[], Union["pymongo.collection.Collection", Awaitable["pymongo.collection.Collection"]]] RedisClient = Union["redis.Redis", Callable[[], Union["redis.Redis", Awaitable["redis.Redis"]]]] -Backend = Literal["pickle", "mongo", "memory", "redis"] +S3Client = Union[Any, Callable[[], Any]] +Backend = Literal["pickle", "mongo", "memory", "redis", "s3"] diff --git a/src/cachier/core.py b/src/cachier/core.py index f0db36e8..7c350140 100644 --- a/src/cachier/core.py +++ b/src/cachier/core.py @@ -19,13 +19,14 @@ from typing import Any, Callable, Optional, Union from warnings import warn -from ._types import RedisClient +from ._types import RedisClient, S3Client from .config import Backend, HashFunc, Mongetter, _update_with_defaults from .cores.base import RecalculationNeeded, _BaseCore from .cores.memory import _MemoryCore from .cores.mongo import _MongoCore from .cores.pickle import _PickleCore from .cores.redis import _RedisCore +from .cores.s3 import _S3Core from .cores.sql import _SQLCore from .util import parse_bytes @@ -34,6 +35,18 @@ ZERO_TIMEDELTA = timedelta(seconds=0) +class _ImmediateAwaitable: + """Lightweight awaitable that yields an immediate value.""" + + def __init__(self, value: Any = None) -> None: + self._value = value + + def __await__(self): + if False: + yield None + return self._value + + def _max_workers(): return int(os.environ.get(MAX_WORKERS_ENVAR_NAME, DEFAULT_MAX_WORKERS)) @@ -170,6 +183,13 @@ def cachier( mongetter: Optional[Mongetter] = None, sql_engine: Optional[Union[str, Any, Callable[[], Any]]] = None, 
redis_client: Optional["RedisClient"] = None, + s3_bucket: Optional[str] = None, + s3_prefix: str = "cachier", + s3_client: Optional["S3Client"] = None, + s3_client_factory: Optional[Callable[[], Any]] = None, + s3_region: Optional[str] = None, + s3_endpoint_url: Optional[str] = None, + s3_config: Optional[Any] = None, stale_after: Optional[timedelta] = None, next_time: Optional[bool] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, @@ -201,9 +221,9 @@ def cachier( Deprecated, use :func:`~cachier.core.cachier.hash_func` instead. backend : str, optional The name of the backend to use. Valid options currently include - 'pickle', 'mongo', 'memory', 'sql', and 'redis'. If not provided, - defaults to 'pickle', unless a core-associated parameter is provided - + 'pickle', 'mongo', 'memory', 'sql', 'redis', and 's3'. If not + provided, defaults to 'pickle', unless a core-associated parameter + is provided. mongetter : callable, optional A callable that takes no arguments and returns a pymongo.Collection object with writing permissions. If provided, the backend is set to @@ -214,6 +234,20 @@ def cachier( redis_client : redis.Redis or callable, optional Redis client instance or callable returning a Redis client. Used for the Redis backend. + s3_bucket : str, optional + The S3 bucket name for cache storage. Required when using the S3 backend. + s3_prefix : str, optional + Key prefix applied to all S3 cache objects. Defaults to ``"cachier"``. + s3_client : boto3 S3 client, optional + A pre-configured boto3 S3 client instance. + s3_client_factory : callable, optional + A callable that returns a boto3 S3 client, allowing lazy initialization. + s3_region : str, optional + AWS region name used when auto-creating the boto3 S3 client. + s3_endpoint_url : str, optional + Custom endpoint URL for S3-compatible services such as MinIO or localstack. + s3_config : botocore.config.Config, optional + Optional botocore Config object passed when auto-creating the client. stale_after : datetime.timedelta, optional The time delta after which a cached result is considered stale. 
Calls made after the result goes stale will trigger a recalculation of the @@ -302,6 +336,19 @@ def cachier( wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=size_limit_bytes, ) + elif backend == "s3": + core = _S3Core( + hash_func=hash_func, + s3_bucket=s3_bucket, + wait_for_calc_timeout=wait_for_calc_timeout, + s3_prefix=s3_prefix, + s3_client=s3_client, + s3_client_factory=s3_client_factory, + s3_region=s3_region, + s3_endpoint_url=s3_endpoint_url, + s3_config=s3_config, + entry_size_limit=size_limit_bytes, + ) else: raise ValueError("specified an invalid core: %s" % backend) @@ -558,10 +605,16 @@ def func_wrapper(*args, **kwargs): def _clear_cache(): """Clear the cache.""" core.clear_cache() + if is_coroutine: + return _ImmediateAwaitable() + return None def _clear_being_calculated(): """Mark all entries in this cache as not being calculated.""" core.clear_being_calculated() + if is_coroutine: + return _ImmediateAwaitable() + return None async def _aclear_cache(): """Clear the cache asynchronously.""" diff --git a/src/cachier/cores/s3.py b/src/cachier/cores/s3.py new file mode 100644 index 00000000..239612ec --- /dev/null +++ b/src/cachier/cores/s3.py @@ -0,0 +1,428 @@ +"""An S3-based caching core for cachier.""" + +import asyncio +import contextlib +import pickle +import time +import warnings +from datetime import datetime, timedelta +from typing import Any, Callable, Optional, Tuple + +try: + import boto3 # type: ignore[import-untyped] + import botocore.exceptions # type: ignore[import-untyped] + + BOTO3_AVAILABLE = True +except ImportError: + BOTO3_AVAILABLE = False + +from .._types import HashFunc +from ..config import CacheEntry +from .base import RecalculationNeeded, _BaseCore, _get_func_str + +S3_SLEEP_DURATION_IN_SEC = 1 + + +class MissingS3Bucket(ValueError): + """Thrown when the s3_bucket keyword argument is missing.""" + + +def _safe_warn(message: str, category: type[Warning] = UserWarning) -> None: + """Emit a warning without raising when warnings are configured as errors.""" + with contextlib.suppress(Warning): + warnings.warn(message, category, stacklevel=2) + + +class _S3Core(_BaseCore): + """S3-based core for Cachier, supporting AWS S3 and S3-compatible backends. + + Async support in this core is delegated rather than native. Since boto3 is a + synchronous client, async methods explicitly offload I/O work to a thread via + ``asyncio.to_thread``. + + Parameters + ---------- + hash_func : callable, optional + A callable to hash function arguments into a cache key string. + s3_bucket : str + The name of the S3 bucket to use for caching. + wait_for_calc_timeout : int, optional + Maximum seconds to wait for a concurrent calculation. 0 means wait forever. + s3_prefix : str, optional + Key prefix for all cache entries. Defaults to ``"cachier"``. + s3_client : boto3 S3 client, optional + A pre-configured boto3 S3 client instance. + s3_client_factory : callable, optional + A callable that returns a boto3 S3 client. Allows lazy initialization. + s3_region : str, optional + AWS region name used when creating the S3 client. + s3_endpoint_url : str, optional + Custom endpoint URL for S3-compatible services (e.g. MinIO, localstack). + s3_config : boto3 Config, optional + Optional ``botocore.config.Config`` object passed when creating the client. + entry_size_limit : int, optional + Maximum allowed size in bytes of a cached value. 
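+
+    Notes
+    -----
+    Instances are normally created for you by ``cachier(backend="s3", ...)``
+    rather than instantiated directly. A rough sketch of that wiring, with a
+    placeholder bucket name and ``my_function`` standing in for the decorated
+    function::
+
+        core = _S3Core(hash_func=None, s3_bucket="my-cache-bucket")
+        core.set_func(my_function)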
+ + """ + + def __init__( + self, + hash_func: Optional[HashFunc], + s3_bucket: Optional[str], + wait_for_calc_timeout: Optional[int] = None, + s3_prefix: str = "cachier", + s3_client: Optional[Any] = None, + s3_client_factory: Optional[Callable[[], Any]] = None, + s3_region: Optional[str] = None, + s3_endpoint_url: Optional[str] = None, + s3_config: Optional[Any] = None, + entry_size_limit: Optional[int] = None, + ): + if not BOTO3_AVAILABLE: + _safe_warn( + "`boto3` was not found. S3 cores will not function. Install with `pip install boto3`.", + ImportWarning, + ) + + super().__init__( + hash_func=hash_func, + wait_for_calc_timeout=wait_for_calc_timeout, + entry_size_limit=entry_size_limit, + ) + + if not s3_bucket: + raise MissingS3Bucket("must specify ``s3_bucket`` when using the s3 core") + + self.s3_bucket = s3_bucket + self.s3_prefix = s3_prefix + self._s3_client = s3_client + self._s3_client_factory = s3_client_factory + self._s3_region = s3_region + self._s3_endpoint_url = s3_endpoint_url + self._s3_config = s3_config + self._func_str: Optional[str] = None + + def set_func(self, func: Callable) -> None: + """Set the function this core will use.""" + super().set_func(func) + self._func_str = _get_func_str(func) + + def _get_s3_client(self) -> Any: + """Return a boto3 S3 client, creating one if not already available.""" + if self._s3_client_factory is not None: + return self._s3_client_factory() + if self._s3_client is not None: + return self._s3_client + kwargs: dict = {} + if self._s3_region: + kwargs["region_name"] = self._s3_region + if self._s3_endpoint_url: + kwargs["endpoint_url"] = self._s3_endpoint_url + if self._s3_config: + kwargs["config"] = self._s3_config + self._s3_client = boto3.client("s3", **kwargs) + return self._s3_client + + def _get_s3_key(self, key: str) -> str: + """Return the full S3 object key for the given cache key.""" + return f"{self.s3_prefix}/{self._func_str}/{key}.pkl" + + def _get_s3_prefix(self) -> str: + """Return the S3 prefix for all objects belonging to this function.""" + return f"{self.s3_prefix}/{self._func_str}/" + + def _load_entry(self, body: bytes) -> Optional[CacheEntry]: + """Deserialize raw S3 object bytes into a CacheEntry.""" + try: + data = pickle.loads(body) + except Exception as exc: + _safe_warn(f"S3 cache entry deserialization failed: {exc}") + return None + + try: + raw_time = data.get("time", datetime.now()) + entry_time = datetime.fromisoformat(raw_time) if isinstance(raw_time, str) else raw_time + + return CacheEntry( + value=data.get("value"), + time=entry_time, + stale=bool(data.get("stale", False)), + _processing=bool(data.get("_processing", False)), + _completed=bool(data.get("_completed", False)), + ) + except Exception as exc: + _safe_warn(f"S3 CacheEntry construction failed: {exc}") + return None + + def _dump_entry(self, entry: CacheEntry) -> bytes: + """Serialize a CacheEntry to bytes for S3 storage.""" + data = { + "value": entry.value, + "time": entry.time.isoformat(), + "stale": entry.stale, + "_processing": entry._processing, + "_completed": entry._completed, + } + return pickle.dumps(data) + + # ------------------------------------------------------------------ + # Core interface + # ------------------------------------------------------------------ + + def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: + """Get a cache entry from S3 by its key. + + Parameters + ---------- + key : str + The cache key string. 
+ + Returns + ------- + tuple + A ``(key, CacheEntry)`` pair, or ``(key, None)`` if not found. + + """ + s3_key = self._get_s3_key(key) + client = self._get_s3_client() + try: + response = client.get_object(Bucket=self.s3_bucket, Key=s3_key) + body = response["Body"].read() + entry = self._load_entry(body) + return key, entry + except botocore.exceptions.ClientError as exc: + if exc.response["Error"]["Code"] in ("NoSuchKey", "404"): + return key, None + _safe_warn(f"S3 get_entry_by_key failed: {exc}") + return key, None + except Exception as exc: + _safe_warn(f"S3 get_entry_by_key failed: {exc}") + return key, None + + def set_entry(self, key: str, func_res: Any) -> bool: + """Store a function result in S3 under the given key. + + Parameters + ---------- + key : str + The cache key string. + func_res : any + The function result to cache. + + Returns + ------- + bool + ``True`` if the entry was stored, ``False`` otherwise. + + """ + if not self._should_store(func_res): + return False + s3_key = self._get_s3_key(key) + client = self._get_s3_client() + entry = CacheEntry( + value=func_res, + time=datetime.now(), + stale=False, + _processing=False, + _completed=True, + ) + try: + client.put_object(Bucket=self.s3_bucket, Key=s3_key, Body=self._dump_entry(entry)) + return True + except Exception as exc: + _safe_warn(f"S3 set_entry failed: {exc}") + return False + + def mark_entry_being_calculated(self, key: str) -> None: + """Mark the given cache entry as currently being calculated. + + Parameters + ---------- + key : str + The cache key string. + + """ + s3_key = self._get_s3_key(key) + client = self._get_s3_client() + entry = CacheEntry( + value=None, + time=datetime.now(), + stale=False, + _processing=True, + _completed=False, + ) + try: + client.put_object(Bucket=self.s3_bucket, Key=s3_key, Body=self._dump_entry(entry)) + except Exception as exc: + _safe_warn(f"S3 mark_entry_being_calculated failed: {exc}") + + def mark_entry_not_calculated(self, key: str) -> None: + """Mark the given cache entry as no longer being calculated. + + Parameters + ---------- + key : str + The cache key string. + + """ + s3_key = self._get_s3_key(key) + client = self._get_s3_client() + try: + response = client.get_object(Bucket=self.s3_bucket, Key=s3_key) + body = response["Body"].read() + entry = self._load_entry(body) + if entry is not None: + entry._processing = False + client.put_object(Bucket=self.s3_bucket, Key=s3_key, Body=self._dump_entry(entry)) + except botocore.exceptions.ClientError as exc: + if exc.response["Error"]["Code"] not in ("NoSuchKey", "404"): + _safe_warn(f"S3 mark_entry_not_calculated failed: {exc}") + except Exception as exc: + _safe_warn(f"S3 mark_entry_not_calculated failed: {exc}") + + def wait_on_entry_calc(self, key: str) -> Any: + """Poll S3 until the entry is no longer being calculated, then return its value. + + Parameters + ---------- + key : str + The cache key string. + + Returns + ------- + any + The cached value once calculation is complete. 
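+
+        Raises
+        ------
+        RecalculationNeeded
+            If the entry disappears from S3 while waiting, signalling that
+            the value needs to be recomputed.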
+ + """ + time_spent = 0 + while True: + time.sleep(S3_SLEEP_DURATION_IN_SEC) + time_spent += S3_SLEEP_DURATION_IN_SEC + _, entry = self.get_entry_by_key(key) + if entry is None: + raise RecalculationNeeded() + if not entry._processing: + return entry.value + self.check_calc_timeout(time_spent) + + def clear_cache(self) -> None: + """Delete all cache entries for this function from S3.""" + client = self._get_s3_client() + prefix = self._get_s3_prefix() + try: + paginator = client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix) + objects_to_delete = [] + for page in pages: + for obj in page.get("Contents", []): + objects_to_delete.append({"Key": obj["Key"]}) + if objects_to_delete: + # S3 delete_objects accepts up to 1000 keys per request + for i in range(0, len(objects_to_delete), 1000): + client.delete_objects( + Bucket=self.s3_bucket, + Delete={"Objects": objects_to_delete[i : i + 1000]}, + ) + except Exception as exc: + _safe_warn(f"S3 clear_cache failed: {exc}") + + def clear_being_calculated(self) -> None: + """Reset the ``_processing`` flag on all entries for this function in S3.""" + client = self._get_s3_client() + prefix = self._get_s3_prefix() + try: + paginator = client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix) + for page in pages: + for obj in page.get("Contents", []): + s3_key = obj["Key"] + try: + response = client.get_object(Bucket=self.s3_bucket, Key=s3_key) + body = response["Body"].read() + entry = self._load_entry(body) + if entry is not None and entry._processing: + entry._processing = False + client.put_object(Bucket=self.s3_bucket, Key=s3_key, Body=self._dump_entry(entry)) + except Exception as exc: + _safe_warn(f"S3 clear_being_calculated entry update failed: {exc}") + except Exception as exc: + _safe_warn(f"S3 clear_being_calculated failed: {exc}") + + def delete_stale_entries(self, stale_after: timedelta) -> None: + """Remove cache entries older than ``stale_after`` from S3. + + Parameters + ---------- + stale_after : datetime.timedelta + Entries older than this duration will be deleted. + + """ + client = self._get_s3_client() + prefix = self._get_s3_prefix() + threshold = datetime.now() - stale_after + try: + paginator = client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix) + for page in pages: + for obj in page.get("Contents", []): + s3_key = obj["Key"] + try: + response = client.get_object(Bucket=self.s3_bucket, Key=s3_key) + body = response["Body"].read() + entry = self._load_entry(body) + if entry is not None and entry.time < threshold: + client.delete_object(Bucket=self.s3_bucket, Key=s3_key) + except Exception as exc: + _safe_warn(f"S3 delete_stale_entries entry check failed: {exc}") + except Exception as exc: + _safe_warn(f"S3 delete_stale_entries failed: {exc}") + + # ------------------------------------------------------------------ + # Async variants explicitly offload sync boto3 operations to avoid + # blocking the event loop thread. + # ------------------------------------------------------------------ + + async def aget_entry(self, args, kwds) -> Tuple[str, Optional[CacheEntry]]: + """Async-compatible variant of :meth:`get_entry`. + + This method delegates to the sync implementation via + ``asyncio.to_thread`` because boto3 is sync-only. 
+ + """ + return await asyncio.to_thread(self.get_entry, args, kwds) + + async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: + """Async-compatible variant of :meth:`get_entry_by_key`. + + This method delegates to the sync implementation via + ``asyncio.to_thread`` because boto3 is sync-only. + + """ + return await asyncio.to_thread(self.get_entry_by_key, key) + + async def aset_entry(self, key: str, func_res: Any) -> bool: + """Async-compatible variant of :meth:`set_entry`. + + This method delegates to the sync implementation via + ``asyncio.to_thread`` because boto3 is sync-only. + + """ + return await asyncio.to_thread(self.set_entry, key, func_res) + + async def amark_entry_being_calculated(self, key: str) -> None: + """Async-compatible variant of :meth:`mark_entry_being_calculated`. + + This method delegates to the sync implementation via + ``asyncio.to_thread`` because boto3 is sync-only. + + """ + await asyncio.to_thread(self.mark_entry_being_calculated, key) + + async def amark_entry_not_calculated(self, key: str) -> None: + """Async-compatible variant of :meth:`mark_entry_not_calculated`. + + This method delegates to the sync implementation via + ``asyncio.to_thread`` because boto3 is sync-only. + + """ + await asyncio.to_thread(self.mark_entry_not_calculated, key) diff --git a/tests/requirements_s3.txt b/tests/requirements_s3.txt new file mode 100644 index 00000000..e230f4aa --- /dev/null +++ b/tests/requirements_s3.txt @@ -0,0 +1,4 @@ +-r requirements.txt + +boto3>=1.26.0 +moto[s3]>=5.0.0 diff --git a/tests/s3_tests/__init__.py b/tests/s3_tests/__init__.py new file mode 100644 index 00000000..f49b8531 --- /dev/null +++ b/tests/s3_tests/__init__.py @@ -0,0 +1 @@ +"""S3 backend tests for cachier.""" diff --git a/tests/s3_tests/conftest.py b/tests/s3_tests/conftest.py new file mode 100644 index 00000000..925c84ef --- /dev/null +++ b/tests/s3_tests/conftest.py @@ -0,0 +1,27 @@ +"""Shared S3 test fixtures.""" + +import pytest + +from .helpers import S3_DEPS_AVAILABLE, TEST_BUCKET, TEST_REGION, skip_if_missing + +if S3_DEPS_AVAILABLE: + import boto3 + from moto import mock_aws + + +@pytest.fixture +def s3_bucket(): + """Yield a mocked S3 bucket name, set up and torn down around each test.""" + skip_if_missing() + with mock_aws(): + client = boto3.client("s3", region_name=TEST_REGION) + client.create_bucket(Bucket=TEST_BUCKET) + yield TEST_BUCKET + + +@pytest.fixture +def s3_client(s3_bucket): + """Yield a boto3 S3 client within the mocked AWS context.""" + # s3_bucket fixture already sets up the mock_aws context manager; + # we just need to return a client pointing at the same mock. 
+ return boto3.client("s3", region_name=TEST_REGION) diff --git a/tests/s3_tests/helpers.py b/tests/s3_tests/helpers.py new file mode 100644 index 00000000..d9a4724a --- /dev/null +++ b/tests/s3_tests/helpers.py @@ -0,0 +1,32 @@ +"""S3 test helpers.""" + +import pytest + +try: + import boto3 + import moto + + S3_DEPS_AVAILABLE = True +except ImportError: + boto3 = None # type: ignore[assignment] + moto = None # type: ignore[assignment] + S3_DEPS_AVAILABLE = False + +TEST_BUCKET = "cachier-test-bucket" +TEST_REGION = "us-east-1" + + +def skip_if_missing(): + """Skip the test if boto3 or moto are not installed.""" + if not S3_DEPS_AVAILABLE: + pytest.skip("boto3 and moto are required for S3 tests") + + +def make_s3_client(endpoint_url=None): + """Return a boto3 S3 client pointed at the moto mock or a custom endpoint.""" + if not S3_DEPS_AVAILABLE: + pytest.skip("boto3 and moto are required for S3 tests") + kwargs = {"region_name": TEST_REGION} + if endpoint_url: + kwargs["endpoint_url"] = endpoint_url + return boto3.client("s3", **kwargs) diff --git a/tests/s3_tests/test_s3_core.py b/tests/s3_tests/test_s3_core.py new file mode 100644 index 00000000..de9e7132 --- /dev/null +++ b/tests/s3_tests/test_s3_core.py @@ -0,0 +1,779 @@ +"""Tests for the S3 caching core.""" + +import asyncio +import contextlib +import threading +import warnings +from datetime import datetime, timedelta +from random import random +from unittest.mock import Mock + +import pytest + +from cachier.config import CacheEntry +from cachier.cores.base import RecalculationNeeded +from tests.s3_tests.helpers import S3_DEPS_AVAILABLE, TEST_BUCKET, TEST_REGION, skip_if_missing + +if S3_DEPS_AVAILABLE: + import boto3 + import botocore + from moto import mock_aws + + from cachier import cachier + from cachier.cores.s3 import MissingS3Bucket, _S3Core + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_core(s3_bucket=TEST_BUCKET, s3_client=None, s3_prefix="cachier"): + """Return a bare _S3Core (set_func must still be called before use).""" + skip_if_missing() + return _S3Core( + hash_func=None, + s3_bucket=s3_bucket, + s3_prefix=s3_prefix, + s3_client=s3_client, + ) + + +# --------------------------------------------------------------------------- +# Basic construction and validation +# --------------------------------------------------------------------------- + + +@pytest.mark.s3 +def test_missing_bucket_raises(): + skip_if_missing() + with pytest.raises(MissingS3Bucket): + _S3Core(hash_func=None, s3_bucket=None) + + +@pytest.mark.s3 +def test_missing_bucket_empty_string_raises(): + skip_if_missing() + with pytest.raises(MissingS3Bucket): + _S3Core(hash_func=None, s3_bucket="") + + +@pytest.mark.s3 +def test_missing_boto3_warns(monkeypatch): + skip_if_missing() + import cachier.cores.s3 as s3_mod + + monkeypatch.setattr(s3_mod, "BOTO3_AVAILABLE", False) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + with contextlib.suppress(Exception): + _S3Core(hash_func=None, s3_bucket="bucket") + assert any("boto3" in str(w.message) for w in caught) + + +# --------------------------------------------------------------------------- +# Core caching behaviour +# --------------------------------------------------------------------------- + + +@pytest.mark.s3 +def test_s3_core_basic_caching(s3_bucket): + """Cached result is returned on the second call.""" + + 
@cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + return random() + x + + _cached.clear_cache() + val1 = _cached(1) + val2 = _cached(1) + assert val1 == val2 + + +@pytest.mark.s3 +def test_s3_core_different_args(s3_bucket): + """Different arguments produce different cache entries.""" + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + return random() + x + + _cached.clear_cache() + val1 = _cached(1) + val2 = _cached(2) + assert val1 != val2 + + +@pytest.mark.s3 +def test_s3_core_skip_cache(s3_bucket): + """cachier__skip_cache bypasses the cache.""" + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + return random() + + _cached.clear_cache() + val1 = _cached(1) + val2 = _cached(1, cachier__skip_cache=True) + assert val1 != val2 + + +@pytest.mark.s3 +def test_s3_core_overwrite_cache(s3_bucket): + """cachier__overwrite_cache forces recalculation.""" + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + return random() + + _cached.clear_cache() + val1 = _cached(1) + val2 = _cached(1, cachier__overwrite_cache=True) + val3 = _cached(1) + assert val1 != val2 + assert val2 == val3 + + +@pytest.mark.s3 +def test_s3_core_stale_after(s3_bucket): + """A result older than stale_after is recomputed.""" + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION, stale_after=timedelta(seconds=1)) + def _cached(x): + return random() + + _cached.clear_cache() + val1 = _cached(1) + import time + + time.sleep(2) + val2 = _cached(1) + assert val1 != val2 + + +@pytest.mark.s3 +def test_s3_core_next_time(s3_bucket): + """With next_time=True, stale result is returned immediately.""" + + @cachier( + backend="s3", + s3_bucket=s3_bucket, + s3_region=TEST_REGION, + stale_after=timedelta(seconds=1), + next_time=True, + ) + def _cached(x): + return random() + + _cached.clear_cache() + val1 = _cached(1) + import time + + time.sleep(2) + val2 = _cached(1) # stale; should return old value immediately + assert val1 == val2 + + +@pytest.mark.s3 +def test_s3_core_allow_none(s3_bucket): + """None results are cached when allow_none=True.""" + call_count = [0] + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION, allow_none=True) + def _cached(x): + call_count[0] += 1 + return None + + _cached.clear_cache() + res1 = _cached(1) + res2 = _cached(1) + assert res1 is None + assert res2 is None + assert call_count[0] == 1 + + +@pytest.mark.s3 +@pytest.mark.asyncio +async def test_s3_core_async_cache_hit(s3_bucket): + """Async S3 cache calls should hit cache on repeated arguments.""" + call_count = [0] + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + async def _cached(x): + call_count[0] += 1 + await asyncio.sleep(0) + return x * 5 + + await _cached.clear_cache() + value_1 = await _cached(6) + value_2 = await _cached(6) + + assert value_1 == 30 + assert value_2 == 30 + assert call_count[0] == 1 + + +@pytest.mark.s3 +@pytest.mark.asyncio +async def test_s3_core_async_get_entry_by_key_missing(s3_bucket): + """aget_entry_by_key delegates correctly and returns missing entries as None.""" + skip_if_missing() + core = _make_core(s3_bucket=s3_bucket) + + def _dummy(x): + return x + + core.set_func(_dummy) + key = core.get_key((), {"x": 123}) + returned_key, entry = await core.aget_entry_by_key(key) + assert returned_key == key + assert entry is None + + +@pytest.mark.s3 +def test_s3_core_none_not_cached_without_allow_none(s3_bucket): + 
"""None results are NOT cached when allow_none=False (default).""" + call_count = [0] + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION, allow_none=False) + def _cached(x): + call_count[0] += 1 + return None + + _cached.clear_cache() + _cached(1) + _cached(1) + assert call_count[0] == 2 + + +@pytest.mark.s3 +def test_s3_core_clear_cache(s3_bucket): + """clear_cache removes all entries so the next call recomputes.""" + call_count = [0] + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + call_count[0] += 1 + return x * 2 + + _cached.clear_cache() + _cached(5) + _cached(5) + assert call_count[0] == 1 + _cached.clear_cache() + _cached(5) + assert call_count[0] == 2 + + +@pytest.mark.s3 +def test_s3_core_clear_being_calculated(s3_bucket): + """clear_being_calculated resets the processing flag on all entries.""" + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + return x + + _cached.clear_cache() + _cached(1) + _cached.clear_being_calculated() # should not raise + + +# --------------------------------------------------------------------------- +# entry_size_limit +# --------------------------------------------------------------------------- + + +@pytest.mark.s3 +def test_s3_entry_size_limit(s3_bucket): + """Results larger than entry_size_limit are not cached.""" + call_count = [0] + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION, entry_size_limit=1) + def _cached(x): + call_count[0] += 1 + return list(range(1000)) + + _cached.clear_cache() + _cached(1) + _cached(1) + assert call_count[0] == 2 + + +# --------------------------------------------------------------------------- +# delete_stale_entries +# --------------------------------------------------------------------------- + + +@pytest.mark.s3 +def test_s3_delete_stale_entries(s3_bucket): + """delete_stale_entries removes entries that are older than stale_after.""" + call_count = [0] + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + call_count[0] += 1 + return x + + _cached.clear_cache() + _cached(1) + assert call_count[0] == 1 + + import time + + time.sleep(1) + # Manually trigger stale entry cleanup for entries older than 0 seconds + from cachier.cores.s3 import _S3Core + + # Access core via closure - we delete with a tiny stale_after so the entry qualifies + client = boto3.client("s3", region_name=TEST_REGION) + s3_core = _S3Core( + hash_func=None, + s3_bucket=s3_bucket, + s3_client=client, + ) + s3_core.set_func(_cached.__wrapped__ if hasattr(_cached, "__wrapped__") else lambda x: x) + s3_core.delete_stale_entries(timedelta(seconds=0)) + + # After deleting we expect a recompute + _cached(1) + assert call_count[0] == 2 + + +# --------------------------------------------------------------------------- +# s3_client_factory +# --------------------------------------------------------------------------- + + +@pytest.mark.s3 +def test_s3_client_factory(s3_bucket): + """s3_client_factory is called each time to obtain the S3 client.""" + factory_calls = [0] + + def my_factory(): + factory_calls[0] += 1 + return boto3.client("s3", region_name=TEST_REGION) + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_client_factory=my_factory) + def _cached(x): + return x * 3 + + _cached.clear_cache() + _cached(7) + _cached(7) + assert factory_calls[0] > 0 + + +# --------------------------------------------------------------------------- +# Thread safety (basic smoke test) +# 
--------------------------------------------------------------------------- + + +@pytest.mark.s3 +def test_s3_core_threadsafe(s3_bucket): + """Multiple threads calling the same cached function all see the same result.""" + results = [] + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x): + return x * 10 + + _cached.clear_cache() + + def _call(): + results.append(_cached(3)) + + threads = [threading.Thread(target=_call) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert all(r == 30 for r in results) + + +# --------------------------------------------------------------------------- +# Error handling / warnings +# --------------------------------------------------------------------------- + + +@pytest.mark.s3 +def test_s3_bad_bucket_warns(): + """Operations against a non-existent bucket emit a warning rather than crashing.""" + skip_if_missing() + with mock_aws(): + # Intentionally do NOT create the bucket + client = boto3.client("s3", region_name=TEST_REGION) + + core = _S3Core(hash_func=None, s3_bucket="nonexistent-bucket", s3_client=client) + + def _dummy(x): + return x + + core.set_func(_dummy) + key = core.get_key((), {"x": 1}) + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + result_key, entry = core.get_entry_by_key(key) + # Should return None (not found / error) without raising + assert entry is None + assert result_key == key + + +@pytest.mark.s3 +def test_core_decorator_constructs_s3_core(monkeypatch): + """The public decorator routes backend='s3' through the _S3Core branch.""" + skip_if_missing() + import cachier.core as core_mod + + captured = {} + + class DummyS3Core: + """Minimal drop-in core used to verify constructor wiring.""" + + def __init__(self, **kwargs): + captured.update(kwargs) + + def set_func(self, func): + self.func = func + self.func_is_method = False + + def get_key(self, args, kwds): + return "dummy-key" + + def get_entry(self, args, kwds): + return "dummy-key", None + + def set_entry(self, key, func_res): + return True + + def mark_entry_being_calculated(self, key): + return None + + def mark_entry_not_calculated(self, key): + return None + + def wait_on_entry_calc(self, key): + return 42 + + def clear_cache(self): + return None + + def clear_being_calculated(self): + return None + + def delete_stale_entries(self, stale_after): + return None + + monkeypatch.setattr(core_mod, "_S3Core", DummyS3Core) + + @core_mod.cachier(backend="s3", s3_bucket="bucket", s3_prefix="prefix") + def decorated(x): + return x + + assert decorated(3) == 3 + assert captured["s3_bucket"] == "bucket" + assert captured["s3_prefix"] == "prefix" + + +@pytest.mark.s3 +def test_s3_internal_helpers_and_error_paths(monkeypatch): + """Exercise internal helper branches and warning paths.""" + skip_if_missing() + + class DummyFactoryClient: + pass + + factory_client = DummyFactoryClient() + + core = _S3Core( + hash_func=None, + s3_bucket=TEST_BUCKET, + s3_prefix="my-prefix", + s3_client_factory=lambda: factory_client, + ) + + def _dummy(x): + return x + + core.set_func(_dummy) + + assert core._get_s3_client() is factory_client + assert core._get_s3_key("abc") == "my-prefix/.tests.s3_tests.test_s3_core._dummy/abc.pkl" + assert core._get_s3_prefix() == "my-prefix/.tests.s3_tests.test_s3_core._dummy/" + + bad_pickle = b"not a pickle" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + assert core._load_entry(bad_pickle) is None + assert 
any("deserialization failed" in str(w.message) for w in caught) + + invalid_data = __import__("pickle").dumps(1) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + assert core._load_entry(invalid_data) is None + assert any("CacheEntry construction failed" in str(w.message) for w in caught) + + +@pytest.mark.s3 +def test_s3_get_client_builds_boto3_client(monkeypatch): + """_get_s3_client forwards region/endpoint/config kwargs when auto-creating.""" + skip_if_missing() + import cachier.cores.s3 as s3_mod + + created = {} + + def fake_client(name, **kwargs): + created["name"] = name + created["kwargs"] = kwargs + return "fake-client" + + monkeypatch.setattr(s3_mod.boto3, "client", fake_client) + + core = _S3Core( + hash_func=None, + s3_bucket=TEST_BUCKET, + s3_region="us-west-2", + s3_endpoint_url="http://localhost:9000", + s3_config=object(), + ) + + assert core._get_s3_client() == "fake-client" + assert created["name"] == "s3" + assert set(created["kwargs"]) == {"region_name", "endpoint_url", "config"} + + created.clear() + no_options_core = _S3Core(hash_func=None, s3_bucket=TEST_BUCKET) + assert no_options_core._get_s3_client() == "fake-client" + assert created["kwargs"] == {} + + +@pytest.mark.s3 +def test_s3_get_entry_by_key_branches(monkeypatch): + """get_entry_by_key handles success and exception branches.""" + skip_if_missing() + + core = _make_core(s3_client=Mock()) + + def _dummy(x): + return x + + core.set_func(_dummy) + + payload = core._dump_entry( + CacheEntry( + value=5, + time=datetime.now(), + stale=False, + _processing=False, + _completed=True, + ) + ) + good_body = Mock(read=Mock(return_value=payload)) + core._s3_client.get_object = Mock(return_value={"Body": good_body}) + _, entry = core.get_entry_by_key("k") + assert entry is not None + assert entry.value == 5 + + client_error = botocore.exceptions.ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "missing"}}, + "GetObject", + ) + core._s3_client.get_object = Mock(side_effect=client_error) + assert core.get_entry_by_key("k") == ("k", None) + + bad_client_error = botocore.exceptions.ClientError( + {"Error": {"Code": "AccessDenied", "Message": "denied"}}, + "GetObject", + ) + core._s3_client.get_object = Mock(side_effect=bad_client_error) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + assert core.get_entry_by_key("k") == ("k", None) + assert any("get_entry_by_key failed" in str(w.message) for w in caught) + + core._s3_client.get_object = Mock(side_effect=RuntimeError("boom")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + assert core.get_entry_by_key("k") == ("k", None) + assert any("get_entry_by_key failed" in str(w.message) for w in caught) + + +@pytest.mark.s3 +def test_s3_set_mark_wait_and_clear_paths(monkeypatch): + """Exercise set/mark/wait and clear path branches.""" + skip_if_missing() + import datetime as dt + + client = Mock() + core = _make_core(s3_client=client) + + def _dummy(x): + return x + + core.set_func(_dummy) + + monkeypatch.setattr(core, "_should_store", lambda _: False) + assert core.set_entry("k", 1) is False + + monkeypatch.setattr(core, "_should_store", lambda _: True) + client.put_object = Mock(return_value=None) + assert core.set_entry("k", 1) is True + + client.put_object = Mock(side_effect=RuntimeError("put failed")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + assert core.set_entry("k", 1) is False + assert 
any("set_entry failed" in str(w.message) for w in caught) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.mark_entry_being_calculated("k") + assert any("mark_entry_being_calculated failed" in str(w.message) for w in caught) + + entry = CacheEntry(value=3, time=dt.datetime.now(), stale=False, _processing=True, _completed=False) + client.get_object = Mock(return_value={"Body": Mock(read=Mock(return_value=core._dump_entry(entry)))}) + client.put_object = Mock(return_value=None) + core.mark_entry_not_calculated("k") + assert client.put_object.called + + no_such_key = botocore.exceptions.ClientError({"Error": {"Code": "NoSuchKey"}}, "GetObject") + client.get_object = Mock(side_effect=no_such_key) + core.mark_entry_not_calculated("k") + + denied = botocore.exceptions.ClientError({"Error": {"Code": "AccessDenied"}}, "GetObject") + client.get_object = Mock(side_effect=denied) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.mark_entry_not_calculated("k") + assert any("mark_entry_not_calculated failed" in str(w.message) for w in caught) + + client.get_object = Mock(return_value={"Body": Mock(read=Mock(return_value=b"bad"))}) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + core.mark_entry_not_calculated("k") + + client.get_object = Mock(side_effect=RuntimeError("unexpected")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.mark_entry_not_calculated("k") + assert any("mark_entry_not_calculated failed" in str(w.message) for w in caught) + + monkeypatch.setattr("cachier.cores.s3.time.sleep", lambda _: None) + sequence = [ + ("k", CacheEntry(value=None, time=dt.datetime.now(), stale=False, _processing=True, _completed=False)), + ("k", CacheEntry(value=9, time=dt.datetime.now(), stale=False, _processing=False, _completed=True)), + ] + monkeypatch.setattr(core, "get_entry_by_key", lambda key: sequence.pop(0)) + assert core.wait_on_entry_calc("k") == 9 + + monkeypatch.setattr(core, "get_entry_by_key", lambda key: ("k", None)) + with pytest.raises(RecalculationNeeded): + core.wait_on_entry_calc("k") + + paginator = Mock() + paginator.paginate.return_value = [{"Contents": [{"Key": f"k{i}"} for i in range(1001)]}] + client.get_paginator = Mock(return_value=paginator) + client.delete_objects = Mock(return_value=None) + core.clear_cache() + assert client.delete_objects.call_count == 2 + + client.get_paginator = Mock(side_effect=RuntimeError("paginate failed")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.clear_cache() + assert any("clear_cache failed" in str(w.message) for w in caught) + + +@pytest.mark.s3 +def test_s3_clear_processing_and_delete_stale_paths(monkeypatch): + """Exercise clear_being_calculated and delete_stale_entries branches.""" + skip_if_missing() + import datetime as dt + + client = Mock() + core = _make_core(s3_client=client) + + def _dummy(x): + return x + + core.set_func(_dummy) + + now = dt.datetime.now() + processing = CacheEntry(value=1, time=now, stale=False, _processing=True, _completed=False) + done = CacheEntry(value=2, time=now, stale=False, _processing=False, _completed=True) + + paginator = Mock() + paginator.paginate.return_value = [{"Contents": [{"Key": "k1"}, {"Key": "k2"}]}] + client.get_paginator = Mock(return_value=paginator) + client.get_object = Mock( + side_effect=[ + {"Body": Mock(read=Mock(return_value=core._dump_entry(processing)))}, + 
{"Body": Mock(read=Mock(return_value=core._dump_entry(done)))}, + ] + ) + client.put_object = Mock(return_value=None) + core.clear_being_calculated() + assert client.put_object.call_count == 1 + + client.get_paginator = Mock(return_value=paginator) + client.get_object = Mock(side_effect=RuntimeError("entry read failed")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.clear_being_calculated() + assert any("clear_being_calculated entry update failed" in str(w.message) for w in caught) + + client.get_paginator = Mock(side_effect=RuntimeError("outer failed")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.clear_being_calculated() + assert any("clear_being_calculated failed" in str(w.message) for w in caught) + + old = CacheEntry(value=1, time=now - dt.timedelta(days=2), stale=False, _processing=False, _completed=True) + fresh = CacheEntry(value=2, time=now, stale=False, _processing=False, _completed=True) + client.get_paginator = Mock(return_value=paginator) + client.get_object = Mock( + side_effect=[ + {"Body": Mock(read=Mock(return_value=core._dump_entry(old)))}, + {"Body": Mock(read=Mock(return_value=core._dump_entry(fresh)))}, + ] + ) + client.delete_object = Mock(return_value=None) + core.delete_stale_entries(dt.timedelta(days=1)) + assert client.delete_object.call_count == 1 + + client.get_paginator = Mock(return_value=paginator) + client.get_object = Mock(side_effect=RuntimeError("entry read failed")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.delete_stale_entries(dt.timedelta(days=1)) + assert any("delete_stale_entries entry check failed" in str(w.message) for w in caught) + + client.get_paginator = Mock(side_effect=RuntimeError("outer failed")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.delete_stale_entries(dt.timedelta(days=1)) + assert any("delete_stale_entries failed" in str(w.message) for w in caught) + + +@pytest.mark.s3 +def test_s3_module_importerror_branch(monkeypatch): + """The module sets BOTO3_AVAILABLE=False when boto3 import fails.""" + import builtins + import importlib.util + from pathlib import Path + + source = Path("src/cachier/cores/s3.py") + spec = importlib.util.spec_from_file_location("cachier.cores.s3_no_boto", source) + module = importlib.util.module_from_spec(spec) + original_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name in {"boto3", "botocore.exceptions"}: + raise ImportError("missing for test") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + assert spec.loader is not None + spec.loader.exec_module(module) + assert module.BOTO3_AVAILABLE is False diff --git a/tests/test_async_core.py b/tests/test_async_core.py index 60b681a0..5adeedb9 100644 --- a/tests/test_async_core.py +++ b/tests/test_async_core.py @@ -358,6 +358,27 @@ def sync_func(x): await sync_func.aclear_cache() +@pytest.mark.memory +@pytest.mark.asyncio +class TestAsyncWrapperMaintenanceMethods: + """Tests for clear helpers exposed on async wrappers.""" + + async def test_clear_methods_are_await_safe(self): + """Async wrappers support both sync and awaited clear_cache usage.""" + + @cachier(backend="memory") + async def async_func(x): + return x + + # Legacy sync usage should keep working. + async_func.clear_cache() + async_func.clear_being_calculated() + + # Awaiting these methods should also work. 
+ await async_func.clear_cache() + await async_func.clear_being_calculated() + + # ============================================================================= # Argument Handling Tests # =============================================================================