diff --git a/scripts/benchmark_numpy_hash.py b/scripts/benchmark_numpy_hash.py
new file mode 100644
index 00000000..8a6a3980
--- /dev/null
+++ b/scripts/benchmark_numpy_hash.py
@@ -0,0 +1,102 @@
+"""Benchmark default Cachier hashing against xxhash for large NumPy arrays."""
+
+from __future__ import annotations
+
+import argparse
+import pickle
+import statistics
+import time
+from typing import Any, Callable, Dict, List
+
+import numpy as np
+
+from cachier.config import _default_hash_func
+
+
+def _xxhash_numpy_hash(args: tuple[Any, ...], kwds: dict[str, Any]) -> str:
+    """Hash call arguments with xxhash, optimized for NumPy arrays.
+
+    Parameters
+    ----------
+    args : tuple[Any, ...]
+        Positional arguments.
+    kwds : dict[str, Any]
+        Keyword arguments.
+
+    Returns
+    -------
+    str
+        xxhash hex digest.
+
+    """
+    import xxhash
+
+    hasher = xxhash.xxh64()
+    hasher.update(b"args")
+    for value in args:
+        if isinstance(value, np.ndarray):
+            hasher.update(value.dtype.str.encode("utf-8"))
+            hasher.update(str(value.shape).encode("utf-8"))
+            hasher.update(value.tobytes(order="C"))
+        else:
+            hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
+
+    hasher.update(b"kwds")
+    for key, value in sorted(kwds.items()):
+        hasher.update(pickle.dumps(key, protocol=pickle.HIGHEST_PROTOCOL))
+        if isinstance(value, np.ndarray):
+            hasher.update(value.dtype.str.encode("utf-8"))
+            hasher.update(str(value.shape).encode("utf-8"))
+            hasher.update(value.tobytes(order="C"))
+        else:
+            hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
+
+    return hasher.hexdigest()
+
+
+def _benchmark(hash_func: Callable[[tuple[Any, ...], dict[str, Any]], str], args: tuple[Any, ...], runs: int) -> float:
+    durations: List[float] = []
+    for _ in range(runs):
+        start = time.perf_counter()
+        hash_func(args, {})
+        durations.append(time.perf_counter() - start)
+    return statistics.median(durations)
+
+
+def main() -> None:
+    """Run benchmark comparing cachier default hashing with xxhash."""
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--elements",
+        type=int,
+        default=10_000_000,
+        help="Number of float64 elements in the benchmark array",
+    )
+    parser.add_argument("--runs", type=int, default=7, help="Number of benchmark runs")
+    parsed = parser.parse_args()
+
+    try:
+        import xxhash  # noqa: F401
+    except ImportError as error:
+        raise SystemExit("Missing dependency: xxhash. Install with `pip install xxhash`.") from error
+
+    array = np.arange(parsed.elements, dtype=np.float64)
+    args = (array,)
+
+    results: Dict[str, float] = {
+        "cachier_default": _benchmark(_default_hash_func, args, parsed.runs),
+        "xxhash_reference": _benchmark(_xxhash_numpy_hash, args, parsed.runs),
+    }
+
+    ratio = results["cachier_default"] / results["xxhash_reference"]
+
+    print(f"Array elements: {parsed.elements:,}")
+    print(f"Array bytes: {array.nbytes:,}")
+    print(f"Runs: {parsed.runs}")
+    print(f"cachier_default median: {results['cachier_default']:.6f}s")
+    print(f"xxhash_reference median: {results['xxhash_reference']:.6f}s")
+    print(f"ratio (cachier_default / xxhash_reference): {ratio:.2f}x")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/cachier/config.py b/src/cachier/config.py
index f04d0cff..37968fcf 100644
--- a/src/cachier/config.py
+++ b/src/cachier/config.py
@@ -9,13 +9,140 @@
 from ._types import Backend, HashFunc, Mongetter
 
 
+def _is_numpy_array(value: Any) -> bool:
+    """Check whether a value is a NumPy ndarray without importing NumPy eagerly.
+
+    Parameters
+    ----------
+    value : Any
+        The value to inspect.
+
+    Returns
+    -------
+    bool
+        True when ``value`` is a NumPy ndarray instance.
+
+    """
+    return type(value).__module__ == "numpy" and type(value).__name__ == "ndarray"
+
+
+def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
+    """Update hasher with NumPy array metadata and buffer content.
+
+    The array content is converted to bytes using C-order (row-major) layout
+    to ensure consistent hashing regardless of memory layout. This operation
+    may create a copy if the array is not already C-contiguous (e.g., for
+    transposed arrays, sliced views, or Fortran-ordered arrays), which has
+    performance implications for large arrays.
+
+    Parameters
+    ----------
+    hasher : hashlib._Hash
+        The hasher to update.
+    value : Any
+        A NumPy ndarray instance.
+
+    Notes
+    -----
+    The ``tobytes(order="C")`` call ensures deterministic hash values by
+    normalizing the memory layout, but may incur a memory copy for
+    non-contiguous arrays. For optimal performance with large arrays,
+    consider using C-contiguous arrays when possible.
+
+    """
+    hasher.update(b"numpy.ndarray")
+    hasher.update(value.dtype.str.encode("utf-8"))
+    hasher.update(str(value.shape).encode("utf-8"))
+    hasher.update(value.tobytes(order="C"))
+
+
+def _update_hash_for_value(hasher: "hashlib._Hash", value: Any, depth: int = 0, max_depth: int = 100) -> None:
+    """Update hasher with a stable representation of a Python value.
+
+    Parameters
+    ----------
+    hasher : hashlib._Hash
+        The hasher to update.
+    value : Any
+        Value to encode.
+    depth : int, optional
+        Current recursion depth (internal use only).
+    max_depth : int, optional
+        Maximum allowed recursion depth to prevent stack overflow.
+
+    Raises
+    ------
+    RecursionError
+        If the recursion depth exceeds max_depth.
+
+    """
+    if depth > max_depth:
+        raise RecursionError(
+            f"Maximum recursion depth ({max_depth}) exceeded while hashing nested "
+            f"data structure. Consider flattening your data or using a custom "
+            f"hash_func parameter."
+        )
+
+    if _is_numpy_array(value):
+        _hash_numpy_array(hasher, value)
+        return
+
+    if isinstance(value, tuple):
+        hasher.update(b"tuple")
+        for item in value:
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
+        return
+
+    if isinstance(value, list):
+        hasher.update(b"list")
+        for item in value:
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
+        return
+
+    if isinstance(value, dict):
+        hasher.update(b"dict")
+        for dict_key in sorted(value):
+            _update_hash_for_value(hasher, dict_key, depth + 1, max_depth)
+            _update_hash_for_value(hasher, value[dict_key], depth + 1, max_depth)
+        return
+
+    if isinstance(value, (set, frozenset)):
+        # Use a deterministic ordering of elements for hashing.
+        hasher.update(b"frozenset" if isinstance(value, frozenset) else b"set")
+        try:
+            # Fast path: works for homogeneous, orderable element types.
+            iterable = sorted(value)
+        except TypeError:
+            # Fallback: impose a deterministic order based on type name and repr.
+            iterable = sorted(value, key=lambda item: (type(item).__name__, repr(item)))
+        for item in iterable:
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
+        return
+    hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
+
+
 def _default_hash_func(args, kwds):
-    # Sort the kwargs to ensure consistent ordering
-    sorted_kwargs = sorted(kwds.items())
-    # Serialize args and sorted_kwargs using pickle or similar
-    serialized = pickle.dumps((args, sorted_kwargs))
-    # Create a hash of the serialized data
-    return hashlib.sha256(serialized).hexdigest()
+    """Compute a stable hash key for function arguments.
+
+    Parameters
+    ----------
+    args : tuple
+        Positional arguments.
+    kwds : dict
+        Keyword arguments.
+
+    Returns
+    -------
+    str
+        A hex digest representing the call arguments.
+
+    """
+    hasher = hashlib.blake2b(digest_size=32)
+    hasher.update(b"args")
+    _update_hash_for_value(hasher, args)
+    hasher.update(b"kwds")
+    _update_hash_for_value(hasher, dict(sorted(kwds.items())))
+    return hasher.hexdigest()
 
 
 def _default_cache_dir():
diff --git a/tests/test_numpy_hash.py b/tests/test_numpy_hash.py
new file mode 100644
index 00000000..90ee71c3
--- /dev/null
+++ b/tests/test_numpy_hash.py
@@ -0,0 +1,49 @@
+"""Tests for NumPy-aware default hash behavior."""
+
+from datetime import timedelta
+
+import pytest
+
+from cachier import cachier
+
+np = pytest.importorskip("numpy")
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_default_hash_func_uses_array_content_for_cache_keys(backend, tmp_path):
+    """Verify equal arrays map to a cache hit and different arrays miss."""
+    call_count = 0
+
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def array_sum(values):
+        nonlocal call_count
+        call_count += 1
+        return int(values.sum())
+
+    arr = np.arange(100_000, dtype=np.int64)
+    arr_copy = arr.copy()
+    changed = arr.copy()
+    changed[-1] = -1
+
+    first = array_sum(arr)
+    assert call_count == 1
+
+    second = array_sum(arr_copy)
+    assert second == first
+    assert call_count == 1
+
+    third = array_sum(changed)
+    assert third != first
+    assert call_count == 2
+
+    array_sum.clear_cache()
diff --git a/tests/test_recursion_depth.py b/tests/test_recursion_depth.py
new file mode 100644
index 00000000..34795650
--- /dev/null
+++ b/tests/test_recursion_depth.py
@@ -0,0 +1,147 @@
+"""Tests for recursion depth protection in hash function."""
+
+from datetime import timedelta
+
+import pytest
+
+from cachier import cachier
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_moderately_nested_structures_work(backend, tmp_path):
+    """Verify that moderately nested structures (< 100 levels) work fine."""
+    call_count = 0
+
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_nested(data):
+        nonlocal call_count
+        call_count += 1
+        return "processed"
+
+    # Create a nested structure with 50 levels (well below the 100 limit)
+    nested_list = []
+    current = nested_list
+    for _ in range(50):
+        inner = []
+        current.append(inner)
+        current = inner
+    current.append("leaf")
+
+    # Should work without issues
+    result1 = process_nested(nested_list)
+    assert result1 == "processed"
+    assert call_count == 1
+
+    # Second call should hit cache
+    result2 = process_nested(nested_list)
+    assert result2 == "processed"
+    assert call_count == 1
+
+    process_nested.clear_cache()
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_deeply_nested_structures_raise_error(backend, tmp_path):
+    """Verify that deeply nested structures (> 100 levels) raise RecursionError."""
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_nested(data):
+        return "processed"
+
+    # Create a nested structure with 150 levels (exceeds the 100 limit)
+    nested_list = []
+    current = nested_list
+    for _ in range(150):
+        inner = []
+        current.append(inner)
+        current = inner
+    current.append("leaf")
+
+    # Should raise RecursionError with a clear message
+    with pytest.raises(
+        RecursionError,
+        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
+    ):
+        process_nested(nested_list)
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_nested_dicts_respect_depth_limit(backend, tmp_path):
+    """Verify that nested dictionaries also respect the depth limit."""
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_dict(data):
+        return "processed"
+
+    # Create nested dictionaries beyond the limit
+    nested_dict = {}
+    current = nested_dict
+    for i in range(150):
+        current[f"level_{i}"] = {}
+        current = current[f"level_{i}"]
+    current["leaf"] = "value"
+
+    # Should raise RecursionError
+    with pytest.raises(
+        RecursionError,
+        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
+    ):
+        process_dict(nested_dict)
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_nested_tuples_respect_depth_limit(backend, tmp_path):
+    """Verify that nested tuples also respect the depth limit."""
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_tuple(data):
+        return "processed"
+
+    # Create nested tuples beyond the limit
+    nested_tuple = ("leaf",)
+    for _ in range(150):
+        nested_tuple = (nested_tuple,)
+
+    # Should raise RecursionError
+    with pytest.raises(
+        RecursionError,
+        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
+    ):
+        process_tuple(nested_tuple)
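
For reviewers, a minimal usage sketch (not part of the patch) of the behavior the NumPy-aware default hash enables; the function name `monthly_means` and the array sizes are illustrative assumptions, not code from this repository:

    import numpy as np
    from cachier import cachier

    @cachier()  # the NumPy-aware default hash_func is used when none is given
    def monthly_means(readings):
        # Expensive reduction over a large array; calling again with an equal
        # array (same dtype, shape, and contents) should now be a cache hit.
        return readings.reshape(12, -1).mean(axis=1)

    data = np.random.default_rng(0).normal(size=12 * 100_000)
    first = monthly_means(data)           # computed and cached
    second = monthly_means(data.copy())   # equal contents -> cache hit
    assert np.allclose(first, second)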