102 changes: 102 additions & 0 deletions scripts/benchmark_numpy_hash.py
@@ -0,0 +1,102 @@
"""Benchmark default Cachier hashing against xxhash for large NumPy arrays."""

from __future__ import annotations

import argparse
import pickle
import statistics
import time
from typing import Any, Callable, Dict, List

import numpy as np

from cachier.config import _default_hash_func


def _xxhash_numpy_hash(args: tuple[Any, ...], kwds: dict[str, Any]) -> str:
"""Hash call arguments with xxhash, optimized for NumPy arrays.

Parameters
----------
args : tuple[Any, ...]
Positional arguments.
kwds : dict[str, Any]
Keyword arguments.

Returns
-------
str
xxhash hex digest.

"""
import xxhash

hasher = xxhash.xxh64()
hasher.update(b"args")
for value in args:
if isinstance(value, np.ndarray):
hasher.update(value.dtype.str.encode("utf-8"))
hasher.update(str(value.shape).encode("utf-8"))
hasher.update(value.tobytes(order="C"))
else:
hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))

hasher.update(b"kwds")
for key, value in sorted(kwds.items()):
hasher.update(pickle.dumps(key, protocol=pickle.HIGHEST_PROTOCOL))
if isinstance(value, np.ndarray):
hasher.update(value.dtype.str.encode("utf-8"))
hasher.update(str(value.shape).encode("utf-8"))
hasher.update(value.tobytes(order="C"))
else:
hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))

return hasher.hexdigest()


def _benchmark(
    hash_func: Callable[[tuple[Any, ...], dict[str, Any]], str],
    args: tuple[Any, ...],
    runs: int,
) -> float:
    """Return the median wall-clock duration of ``runs`` hash invocations."""
    durations: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()
        hash_func(args, {})
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


def main() -> None:
"""Run benchmark comparing cachier default hashing with xxhash."""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--elements",
type=int,
default=10_000_000,
help="Number of float64 elements in the benchmark array",
)
parser.add_argument("--runs", type=int, default=7, help="Number of benchmark runs")
parsed = parser.parse_args()

try:
import xxhash # noqa: F401
except ImportError as error:
raise SystemExit("Missing dependency: xxhash. Install with `pip install xxhash`.") from error

array = np.arange(parsed.elements, dtype=np.float64)
args = (array,)

results: Dict[str, float] = {
"cachier_default": _benchmark(_default_hash_func, args, parsed.runs),
"xxhash_reference": _benchmark(_xxhash_numpy_hash, args, parsed.runs),
}

ratio = results["cachier_default"] / results["xxhash_reference"]

print(f"Array elements: {parsed.elements:,}")
print(f"Array bytes: {array.nbytes:,}")
print(f"Runs: {parsed.runs}")
print(f"cachier_default median: {results['cachier_default']:.6f}s")
print(f"xxhash_reference median: {results['xxhash_reference']:.6f}s")
print(f"ratio (cachier_default / xxhash_reference): {ratio:.2f}x")


if __name__ == "__main__":
main()
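
The script can be run directly once the optional dependency is installed; the flags map to the argparse options defined above:

    pip install xxhash
    python scripts/benchmark_numpy_hash.py --elements 5000000 --runs 5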
139 changes: 133 additions & 6 deletions src/cachier/config.py
@@ -9,13 +9,140 @@
from ._types import Backend, HashFunc, Mongetter


def _is_numpy_array(value: Any) -> bool:
"""Check whether a value is a NumPy ndarray without importing NumPy eagerly.

Parameters
----------
value : Any
The value to inspect.

Returns
-------
bool
True when ``value`` is a NumPy ndarray instance.

"""
return type(value).__module__ == "numpy" and type(value).__name__ == "ndarray"


def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
"""Update hasher with NumPy array metadata and buffer content.

The array content is converted to bytes using C-order (row-major) layout
to ensure consistent hashing regardless of memory layout. This operation
may create a copy if the array is not already C-contiguous (e.g., for
transposed arrays, sliced views, or Fortran-ordered arrays), which has
performance implications for large arrays.

Parameters
----------
hasher : hashlib._Hash
The hasher to update.
value : Any
A NumPy ndarray instance.

Notes
-----
The ``tobytes(order="C")`` call ensures deterministic hash values by
normalizing the memory layout, but may incur a memory copy for
non-contiguous arrays. For optimal performance with large arrays,
consider using C-contiguous arrays when possible.

"""
hasher.update(b"numpy.ndarray")
hasher.update(value.dtype.str.encode("utf-8"))
hasher.update(str(value.shape).encode("utf-8"))
hasher.update(value.tobytes(order="C"))
Copilot AI (Feb 17, 2026): The tobytes(order="C") call creates a copy if the array is not already C-contiguous, which is correct for ensuring consistent hashing. However, consider documenting this behavior in the function docstring, as it has performance implications for large non-contiguous arrays (e.g., sliced views, transposed arrays). This is working as intended but worth noting for users.

Member Author: @copilot open a new pull request to apply changes based on this feedback
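
A minimal sketch of the copy behavior under discussion (illustrative only, not part of the diff):

    import numpy as np

    a = np.arange(1_000_000, dtype=np.float64).reshape(1000, 1000)
    view = a.T  # transposed view: same data, but not C-contiguous
    assert not view.flags["C_CONTIGUOUS"]

    # tobytes(order="C") must materialize a C-ordered copy of the ~8 MB buffer.
    buf = view.tobytes(order="C")

    # Code hashing the same non-contiguous array repeatedly can normalize once:
    contig = np.ascontiguousarray(view)  # one explicit copy, reused thereafter
    assert contig.tobytes() == buf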



def _update_hash_for_value(hasher: "hashlib._Hash", value: Any, depth: int = 0, max_depth: int = 100) -> None:
"""Update hasher with a stable representation of a Python value.

Parameters
----------
hasher : hashlib._Hash
The hasher to update.
value : Any
Value to encode.
depth : int, optional
Current recursion depth (internal use only).
max_depth : int, optional
Maximum allowed recursion depth to prevent stack overflow.

Raises
------
RecursionError
If the recursion depth exceeds max_depth.

"""
if depth > max_depth:
raise RecursionError(
f"Maximum recursion depth ({max_depth}) exceeded while hashing nested "
f"data structure. Consider flattening your data or using a custom "
f"hash_func parameter."
)

if _is_numpy_array(value):
_hash_numpy_array(hasher, value)
return

if isinstance(value, tuple):
hasher.update(b"tuple")
for item in value:
_update_hash_for_value(hasher, item, depth + 1, max_depth)
return

if isinstance(value, list):
hasher.update(b"list")
for item in value:
_update_hash_for_value(hasher, item, depth + 1, max_depth)
return

if isinstance(value, dict):
hasher.update(b"dict")
for dict_key in sorted(value):
_update_hash_for_value(hasher, dict_key, depth + 1, max_depth)
_update_hash_for_value(hasher, value[dict_key], depth + 1, max_depth)
return

if isinstance(value, (set, frozenset)):
# Use a deterministic ordering of elements for hashing.
hasher.update(b"frozenset" if isinstance(value, frozenset) else b"set")
try:
# Fast path: works for homogeneous, orderable element types.
iterable = sorted(value)
except TypeError:
# Fallback: impose a deterministic order based on type name and repr.
iterable = sorted(value, key=lambda item: (type(item).__name__, repr(item)))
for item in iterable:
            _update_hash_for_value(hasher, item, depth + 1, max_depth)
return
hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))


def _default_hash_func(args, kwds):
-    # Sort the kwargs to ensure consistent ordering
-    sorted_kwargs = sorted(kwds.items())
-    # Serialize args and sorted_kwargs using pickle or similar
-    serialized = pickle.dumps((args, sorted_kwargs))
-    # Create a hash of the serialized data
-    return hashlib.sha256(serialized).hexdigest()
"""Compute a stable hash key for function arguments.

Parameters
----------
args : tuple
Positional arguments.
kwds : dict
Keyword arguments.

Returns
-------
str
A hex digest representing the call arguments.

"""
hasher = hashlib.blake2b(digest_size=32)
hasher.update(b"args")
_update_hash_for_value(hasher, args)
hasher.update(b"kwds")
_update_hash_for_value(hasher, dict(sorted(kwds.items())))
return hasher.hexdigest()
Comment on lines 124 to +145

Copilot AI (Feb 17, 2026): Changing the hash algorithm from SHA256 to blake2b is a breaking change that will invalidate all existing caches. When users upgrade to this version, their cached function results will not be found because the cache keys will be different. Consider documenting this breaking change in the PR description or adding a migration guide. Alternatively, consider versioning the hash function or providing a compatibility mode that can read old cache entries.
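
One shape the suggested compatibility mode could take (a hypothetical sketch; _legacy_sha256_hash, _versioned_hash, and candidate_keys are invented names, not cachier API):

    import hashlib
    import pickle

    HASH_VERSION = b"cachier-hash-v2"  # hypothetical tag, bumped when the key scheme changes

    def _legacy_sha256_hash(args, kwds):
        # The pre-change scheme, kept only so existing cache entries stay addressable.
        serialized = pickle.dumps((args, sorted(kwds.items())))
        return hashlib.sha256(serialized).hexdigest()

    def _versioned_hash(args, kwds):
        # New entries fold the scheme version into the key itself.
        hasher = hashlib.blake2b(digest_size=32)
        hasher.update(HASH_VERSION)
        hasher.update(pickle.dumps((args, sorted(kwds.items()))))
        return hasher.hexdigest()

    def candidate_keys(args, kwds):
        # A compatibility-mode lookup probes the new key first and falls back to
        # the legacy key before treating the call as a cache miss.
        return [_versioned_hash(args, kwds), _legacy_sha256_hash(args, kwds)]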


def _default_cache_dir():
49 changes: 49 additions & 0 deletions tests/test_numpy_hash.py
@@ -0,0 +1,49 @@
"""Tests for NumPy-aware default hash behavior."""

from datetime import timedelta

import pytest

from cachier import cachier

np = pytest.importorskip("numpy")


@pytest.mark.parametrize(
"backend",
[
pytest.param("memory", marks=pytest.mark.memory),
pytest.param("pickle", marks=pytest.mark.pickle),
],
)
def test_default_hash_func_uses_array_content_for_cache_keys(backend, tmp_path):
"""Verify equal arrays map to a cache hit and different arrays miss."""
call_count = 0

decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
if backend == "pickle":
decorator_kwargs["cache_dir"] = tmp_path

@cachier(**decorator_kwargs)
def array_sum(values):
nonlocal call_count
call_count += 1
return int(values.sum())

arr = np.arange(100_000, dtype=np.int64)
arr_copy = arr.copy()
changed = arr.copy()
changed[-1] = -1

first = array_sum(arr)
assert call_count == 1

second = array_sum(arr_copy)
assert second == first
assert call_count == 1

third = array_sum(changed)
assert third != first
assert call_count == 2

array_sum.clear_cache()
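
Assuming the repository's pytest configuration registers the memory and pickle markers used in the parametrization above, the new tests can be run directly:

    pytest tests/test_numpy_hash.py -q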