# Add NumPy-aware default hashing, regression test, and xxhash benchmark (#337)
base: master
Changes from all commits: 6a149bc, 3958418, 8c94ac3, c33c6be, b2ba0ba, 07dd62c
New file, benchmark script (`@@ -0,0 +1,102 @@`):

```python
"""Benchmark default Cachier hashing against xxhash for large NumPy arrays."""

from __future__ import annotations

import argparse
import pickle
import statistics
import time
from typing import Any, Callable, Dict, List

import numpy as np

from cachier.config import _default_hash_func


def _xxhash_numpy_hash(args: tuple[Any, ...], kwds: dict[str, Any]) -> str:
    """Hash call arguments with xxhash, optimized for NumPy arrays.

    Parameters
    ----------
    args : tuple[Any, ...]
        Positional arguments.
    kwds : dict[str, Any]
        Keyword arguments.

    Returns
    -------
    str
        xxhash hex digest.

    """
    import xxhash

    hasher = xxhash.xxh64()
    hasher.update(b"args")
    for value in args:
        if isinstance(value, np.ndarray):
            hasher.update(value.dtype.str.encode("utf-8"))
            hasher.update(str(value.shape).encode("utf-8"))
            hasher.update(value.tobytes(order="C"))
        else:
            hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))

    hasher.update(b"kwds")
    for key, value in sorted(kwds.items()):
        hasher.update(pickle.dumps(key, protocol=pickle.HIGHEST_PROTOCOL))
        if isinstance(value, np.ndarray):
            hasher.update(value.dtype.str.encode("utf-8"))
            hasher.update(str(value.shape).encode("utf-8"))
            hasher.update(value.tobytes(order="C"))
        else:
            hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))

    return hasher.hexdigest()


def _benchmark(
    hash_func: Callable[[tuple[Any, ...], dict[str, Any]], str],
    args: tuple[Any, ...],
    runs: int,
) -> float:
    durations: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()
        hash_func(args, {})
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


def main() -> None:
    """Run benchmark comparing cachier default hashing with xxhash."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--elements",
        type=int,
        default=10_000_000,
        help="Number of float64 elements in the benchmark array",
    )
    parser.add_argument("--runs", type=int, default=7, help="Number of benchmark runs")
    parsed = parser.parse_args()

    try:
        import xxhash  # noqa: F401
    except ImportError as error:
        raise SystemExit(
            "Missing dependency: xxhash. Install with `pip install xxhash`."
        ) from error

    array = np.arange(parsed.elements, dtype=np.float64)
    args = (array,)

    results: Dict[str, float] = {
        "cachier_default": _benchmark(_default_hash_func, args, parsed.runs),
        "xxhash_reference": _benchmark(_xxhash_numpy_hash, args, parsed.runs),
    }

    ratio = results["cachier_default"] / results["xxhash_reference"]

    print(f"Array elements: {parsed.elements:,}")
    print(f"Array bytes: {array.nbytes:,}")
    print(f"Runs: {parsed.runs}")
    print(f"cachier_default median: {results['cachier_default']:.6f}s")
    print(f"xxhash_reference median: {results['xxhash_reference']:.6f}s")
    print(f"ratio (cachier_default / xxhash_reference): {ratio:.2f}x")


if __name__ == "__main__":
    main()
```
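Both hash functions share the `(args, kwds) -> str` signature, so they can also be exercised directly. A minimal sketch, assuming `xxhash` is installed and the benchmark module is importable (`benchmark_numpy_hash` is a hypothetical module name; the diff does not show the file path):

```python
import numpy as np

from cachier.config import _default_hash_func
# Hypothetical module name; the PR diff does not show the benchmark's path.
from benchmark_numpy_hash import _benchmark, _xxhash_numpy_hash

arr = np.arange(1_000_000, dtype=np.float64)

# Each hasher is deterministic for identical array content. Their digests
# differ from each other by construction (BLAKE2b vs. xxh64), so the
# benchmark compares timing only, never key equality across hashers.
assert _default_hash_func((arr,), {}) == _default_hash_func((arr.copy(),), {})
assert _xxhash_numpy_hash((arr,), {}) == _xxhash_numpy_hash((arr.copy(),), {})

# _benchmark returns the median wall-clock seconds over the given runs.
print(_benchmark(_xxhash_numpy_hash, (arr,), 3))
```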
Modified file, `_default_hash_func` and its new helpers (`@@ -9,13 +9,140 @@`). Note one fix relative to the submitted diff: the set/frozenset branch now passes `depth + 1, max_depth` when recursing, so the recursion guard is not reset for containers nested inside sets.

```diff
@@ -9,13 +9,140 @@
 from ._types import Backend, HashFunc, Mongetter
 
 
+def _is_numpy_array(value: Any) -> bool:
+    """Check whether a value is a NumPy ndarray without importing NumPy eagerly.
+
+    Parameters
+    ----------
+    value : Any
+        The value to inspect.
+
+    Returns
+    -------
+    bool
+        True when ``value`` is a NumPy ndarray instance.
+
+    """
+    return type(value).__module__ == "numpy" and type(value).__name__ == "ndarray"
+
+
+def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
+    """Update hasher with NumPy array metadata and buffer content.
+
+    The array content is converted to bytes using C-order (row-major) layout
+    to ensure consistent hashing regardless of memory layout. This operation
+    may create a copy if the array is not already C-contiguous (e.g., for
+    transposed arrays, sliced views, or Fortran-ordered arrays), which has
+    performance implications for large arrays.
+
+    Parameters
+    ----------
+    hasher : hashlib._Hash
+        The hasher to update.
+    value : Any
+        A NumPy ndarray instance.
+
+    Notes
+    -----
+    The ``tobytes(order="C")`` call ensures deterministic hash values by
+    normalizing the memory layout, but may incur a memory copy for
+    non-contiguous arrays. For optimal performance with large arrays,
+    consider using C-contiguous arrays when possible.
+
+    """
+    hasher.update(b"numpy.ndarray")
+    hasher.update(value.dtype.str.encode("utf-8"))
+    hasher.update(str(value.shape).encode("utf-8"))
+    hasher.update(value.tobytes(order="C"))
+
+
+def _update_hash_for_value(
+    hasher: "hashlib._Hash", value: Any, depth: int = 0, max_depth: int = 100
+) -> None:
+    """Update hasher with a stable representation of a Python value.
+
+    Parameters
+    ----------
+    hasher : hashlib._Hash
+        The hasher to update.
+    value : Any
+        Value to encode.
+    depth : int, optional
+        Current recursion depth (internal use only).
+    max_depth : int, optional
+        Maximum allowed recursion depth to prevent stack overflow.
+
+    Raises
+    ------
+    RecursionError
+        If the recursion depth exceeds max_depth.
+
+    """
+    if depth > max_depth:
+        raise RecursionError(
+            f"Maximum recursion depth ({max_depth}) exceeded while hashing nested "
+            "data structure. Consider flattening your data or using a custom "
+            "hash_func parameter."
+        )
+
+    if _is_numpy_array(value):
+        _hash_numpy_array(hasher, value)
+        return
+
+    if isinstance(value, tuple):
+        hasher.update(b"tuple")
+        for item in value:
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
+        return
+
+    if isinstance(value, list):
+        hasher.update(b"list")
+        for item in value:
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
+        return
+
+    if isinstance(value, dict):
+        hasher.update(b"dict")
+        for dict_key in sorted(value):
+            _update_hash_for_value(hasher, dict_key, depth + 1, max_depth)
+            _update_hash_for_value(hasher, value[dict_key], depth + 1, max_depth)
+        return
+
+    if isinstance(value, (set, frozenset)):
+        # Use a deterministic ordering of elements for hashing.
+        hasher.update(b"frozenset" if isinstance(value, frozenset) else b"set")
+        try:
+            # Fast path: works for homogeneous, orderable element types.
+            iterable = sorted(value)
+        except TypeError:
+            # Fallback: impose a deterministic order based on type name and repr.
+            iterable = sorted(value, key=lambda item: (type(item).__name__, repr(item)))
+        for item in iterable:
+            # Propagate the recursion depth so deeply nested containers
+            # inside sets still trip the max_depth guard.
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
+        return
+
+    hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
+
+
 def _default_hash_func(args, kwds):
-    # Sort the kwargs to ensure consistent ordering
-    sorted_kwargs = sorted(kwds.items())
-    # Serialize args and sorted_kwargs using pickle or similar
-    serialized = pickle.dumps((args, sorted_kwargs))
-    # Create a hash of the serialized data
-    return hashlib.sha256(serialized).hexdigest()
+    """Compute a stable hash key for function arguments.
+
+    Parameters
+    ----------
+    args : tuple
+        Positional arguments.
+    kwds : dict
+        Keyword arguments.
+
+    Returns
+    -------
+    str
+        A hex digest representing the call arguments.
+
+    """
+    hasher = hashlib.blake2b(digest_size=32)
+    hasher.update(b"args")
+    _update_hash_for_value(hasher, args)
+    hasher.update(b"kwds")
+    _update_hash_for_value(hasher, dict(sorted(kwds.items())))
+    return hasher.hexdigest()
 
 
 def _default_cache_dir():
```
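As a quick illustration of the new default's behavior (a sketch only; `_default_hash_func` is a private helper, imported here purely for demonstration): equal array content yields equal cache keys, any element change yields a new key, and the C-order normalization makes the key independent of memory layout:

```python
import numpy as np

from cachier.config import _default_hash_func

a = np.arange(10, dtype=np.float64)
b = a.copy()
b[0] = 99.0

# Content-based keys: equal arrays hash alike, changed arrays do not.
assert _default_hash_func((a,), {}) == _default_hash_func((a.copy(),), {})
assert _default_hash_func((a,), {}) != _default_hash_func((b,), {})

# Layout-independence: a Fortran-ordered copy hashes identically, because
# tobytes(order="C") normalizes to row-major bytes before hashing.
c_order = np.ascontiguousarray(a.reshape(2, 5))
f_order = np.asfortranarray(a.reshape(2, 5))
assert _default_hash_func((c_order,), {}) == _default_hash_func((f_order,), {})
```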
New file, regression test (`@@ -0,0 +1,49 @@`):

```python
"""Tests for NumPy-aware default hash behavior."""

from datetime import timedelta

import pytest

from cachier import cachier

np = pytest.importorskip("numpy")


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_default_hash_func_uses_array_content_for_cache_keys(backend, tmp_path):
    """Verify equal arrays map to a cache hit and different arrays miss."""
    call_count = 0

    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
    if backend == "pickle":
        decorator_kwargs["cache_dir"] = tmp_path

    @cachier(**decorator_kwargs)
    def array_sum(values):
        nonlocal call_count
        call_count += 1
        return int(values.sum())

    arr = np.arange(100_000, dtype=np.int64)
    arr_copy = arr.copy()
    changed = arr.copy()
    changed[-1] = -1

    first = array_sum(arr)
    assert call_count == 1

    second = array_sum(arr_copy)
    assert second == first
    assert call_count == 1

    third = array_sum(changed)
    assert third != first
    assert call_count == 2

    array_sum.clear_cache()
```
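The test can be run on its own via the `memory` and `pickle` markers used in the parametrization; a sketch (the test file's path is not shown in the diff, so `tests/test_numpy_hash.py` is an assumed name):

```python
import pytest

# Select one backend variant via its marker; adjust the path to the
# repository's actual layout.
pytest.main(["-m", "memory", "tests/test_numpy_hash.py", "-v"])
```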
Review comment (on lines 124 to +145, the `_hash_numpy_array` region):

> The `tobytes(order="C")` call creates a copy if the array is not already C-contiguous, which is correct for ensuring consistent hashing. However, consider documenting this behavior in the function docstring, as it has performance implications for large non-contiguous arrays (e.g., sliced views, transposed arrays). This is working as intended but worth noting for users.

Reply:

> @copilot open a new pull request to apply changes based on this feedback
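To make the reviewer's point concrete, a minimal sketch of when `tobytes(order="C")` must materialize a copy, and how a caller could normalize layout up front (standard NumPy APIs only):

```python
import numpy as np

big = np.arange(1_000_000, dtype=np.float64).reshape(1000, 1000)
view = big.T  # transposed view: same buffer, not C-contiguous
print(view.flags["C_CONTIGUOUS"])  # False

# On a non-C-contiguous array, tobytes(order="C") first gathers the data
# into row-major order, an extra full-size pass for large arrays.
payload = view.tobytes(order="C")

# Hashing the same array repeatedly? Normalize once: ascontiguousarray
# copies only when needed and is a no-op for already-C-ordered input.
contig = np.ascontiguousarray(view)
assert contig.tobytes(order="C") == payload
```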