3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
@@ -256,6 +256,9 @@ Optimizations
* GITHUB#15474: Use bulk scoring provided by RandomVectorScorers for new scalar quantized formats provided through
Lucene104ScalarQuantizedVectorsFormat and Lucene104HnswScalarQuantizedVectorsFormat (Ben Trent)

* GITHUB#15500: Use bulk scoring for filtered HNSW search and for entry-point scoring in the graph. This should
provide speed improvements when using vector scorers that satisfy the bulk scoring interface. (Ben Trent)

Bug Fixes
---------------------
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException
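The new entry relies on a bulk scoring contract that shows up throughout the diffs below. As a rough sketch of that contract — the interface name here is hypothetical, and the signatures are inferred from the scorer.score(...) and scorer.bulkScore(...) calls in this PR (the real scorer is Lucene's RandomVectorScorer):

    import java.io.IOException;

    // Hypothetical interface sketching the contract used in this PR. As used in
    // the diffs below, bulkScore fills scores[0..numNodes) for the given node
    // ordinals and returns the maximum of those scores, which callers compare
    // against their pruning threshold to skip the whole batch when possible.
    interface BulkVectorScorer {
      float score(int node) throws IOException;

      float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException;
    }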
------------------------------------------------------------
@@ -18,6 +18,7 @@

import java.io.IOException;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;

/**
@@ -81,4 +82,38 @@ public void search(
}
searchLevel(results, scorer, 0, eps, graph, acceptOrds);
}

protected static void scoreEntryPoints(
KnnCollector results,
RandomVectorScorer scorer,
BitSet visited,
int[] eps,
Bits acceptOrds,
NeighborQueue candidates,
float[] scores)
throws IOException {
assert eps != null && eps.length > 0;
assert scores != null && scores.length >= eps.length;
if (eps.length == 1) {
visited.set(eps[0]);
float score = scorer.score(eps[0]);
results.incVisitedCount(1);
candidates.add(eps[0], score);
if (acceptOrds == null || acceptOrds.get(eps[0])) {
results.collect(eps[0], score);
}
} else {
scorer.bulkScore(eps, scores, eps.length);
[Review comment — @john-wagster (Contributor), Dec 13, 2025]
Saying this out loud in case my assumption is wrong. I assume the reason we don't need benchmarking here is that we know from prior work at the leaf level that there's definitely a benefit to bulk scoring eps here, instead of doing an early-termination check for each entry point.

results.incVisitedCount(eps.length);
for (int i = 0; i < eps.length; i++) {
float score = scores[i];
int ep = eps[i];
visited.set(ep);
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
results.collect(ep, score);
}
}
}
}
}
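For intuition about the bulk call used above, a minimal fallback could look like the following. This is a sketch assuming the RandomVectorScorer signatures used in this method, not the committed implementation; real scorers (e.g., the quantized formats mentioned in CHANGES.txt) would override it with vectorized batch kernels.

    import java.io.IOException;
    import org.apache.lucene.util.hnsw.RandomVectorScorer;

    // Minimal fallback sketch: scores each ordinal individually, records it in
    // scores[i], and returns the max so the caller can skip the whole batch when
    // nothing in it beats the current threshold.
    static float bulkScoreFallback(
        RandomVectorScorer scorer, int[] nodes, float[] scores, int numNodes)
        throws IOException {
      float max = Float.NEGATIVE_INFINITY;
      for (int i = 0; i < numNodes; i++) {
        scores[i] = scorer.score(nodes[i]);
        max = Math.max(max, scores[i]);
      }
      return max;
    }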
------------------------------------------------------------
@@ -28,8 +28,9 @@

/**
* Searches an HNSW graph to find nearest neighbors to a query vector. This particular
* implementation is optimized for a filtered search, inspired by the ACORN-1 algorithm.
* https://arxiv.org/abs/2403.04871 However, this implementation is augmented in some ways, mainly:
* implementation is optimized for a filtered search, inspired by the <a
* href="https://arxiv.org/abs/2403.04871">ACORN-1 algorithm</a>. However, this implementation is
* augmented in some ways, mainly:
*
* <ul>
* <li>It dynamically determines when the optimized filter step should occur based on some
@@ -114,18 +115,15 @@ void searchLevel(

prepareScratchState();

for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
if (results.earlyTerminated()) {
return;
}
float score = scorer.score(ep);
results.incVisitedCount(1);
candidates.add(ep, score);
if (acceptOrds.get(ep)) {
results.collect(ep, score);
}
}
if (bulkScores == null || bulkScores.length < eps.length) {
bulkScores = new float[eps.length];
}
if (results.earlyTerminated()) {
return;
}
scoreEntryPoints(results, scorer, visited, eps, acceptOrds, candidates, bulkScores);
if (results.earlyTerminated()) {
return;
}
// Collect the vectors to score and potentially add as candidates
IntArrayQueue toScore = new IntArrayQueue(graph.maxConn() * 2 * maxExplorationMultiplier);
@@ -190,17 +188,29 @@ void searchLevel(
}
}
// Score the vectors and add them to the candidate list
int toScoreOrd;
while ((toScoreOrd = toScore.poll()) != NO_MORE_DOCS) {
float friendSimilarity = scorer.score(toScoreOrd);
results.incVisitedCount(1);
if (friendSimilarity > minAcceptedSimilarity) {
candidates.add(toScoreOrd, friendSimilarity);
if (results.collect(toScoreOrd, friendSimilarity)) {
minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity());
if (bulkScores == null || bulkScores.length < toScore.count()) {
bulkScores = new float[toScore.count()];
}
assert toScore.upto == 0;
float maxScore =
toScore.count() > 0
? scorer.bulkScore(toScore.nodes, bulkScores, toScore.size)
: Float.NEGATIVE_INFINITY;
results.incVisitedCount(toScore.count());
if (maxScore > minAcceptedSimilarity) {
for (int i = 0; i < toScore.count(); i++) {
int idx = i + toScore.upto;
float friendSimilarity = bulkScores[idx];
if (friendSimilarity > minAcceptedSimilarity) {
int ord = toScore.nodes[idx];
candidates.add(ord, friendSimilarity);
if (results.collect(ord, friendSimilarity)) {
minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity());
}
}
}
}
toScore.upto = toScore.size; // all scored
if (results.getSearchStrategy() != null) {
results.getSearchStrategy().nextVectorsBlock();
}
@@ -213,7 +223,7 @@ private void prepareScratchState() {
}

private static class IntArrayQueue {
private int[] nodes;
private final int[] nodes;
private int upto;
private int size;

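The toScore scratch queue above interleaves two access patterns: add/poll while collecting candidates, and direct reads of nodes[upto + i] after one bulk call. A hedged reconstruction of the class from the fields and call sites visible in this diff — the method bodies are assumptions, not the committed code:

    // Reconstructed from the diff above; bodies are assumptions. count() is the
    // number of ordinals not yet consumed, and poll() advances upto, which is why
    // the bulk path asserts upto == 0 before reading nodes[i + upto] and then marks
    // everything consumed via upto = size. NO_MORE_DOCS is Lucene's
    // DocIdSetIterator sentinel, as in the surrounding code.
    private static class IntArrayQueue {
      private final int[] nodes;
      private int upto; // next index to consume
      private int size; // number of ordinals queued so far

      IntArrayQueue(int capacity) {
        nodes = new int[capacity];
      }

      void add(int node) {
        nodes[size++] = node;
      }

      int count() {
        return size - upto;
      }

      int poll() {
        return upto < size ? nodes[upto++] : NO_MORE_DOCS;
      }
    }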
------------------------------------------------------------
@@ -226,7 +226,7 @@ int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector
return new int[] {currentEp};
}
int size = getGraphSize(graph);
prepareScratchState(size);
prepareScratchState(size, graph.maxConn() * 2);
float currentScore = scorer.score(currentEp);
collector.incVisitedCount(1);
boolean foundBetter;
@@ -238,6 +238,7 @@ int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector
foundBetter = false;
graphSeek(graph, level, currentEp);
int friendOrd;
int numNodes = 0;
while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
@@ -246,12 +247,21 @@ int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector
if (collector.earlyTerminated()) {
return new int[] {UNK_EP};
}
float friendSimilarity = scorer.score(friendOrd);
collector.incVisitedCount(1);
if (friendSimilarity > currentScore) {
currentScore = friendSimilarity;
currentEp = friendOrd;
foundBetter = true;
bulkNodes[numNodes++] = friendOrd;
}
float maxScore =
numNodes > 0
? scorer.bulkScore(bulkNodes, bulkScores, numNodes)
: Float.NEGATIVE_INFINITY;
collector.incVisitedCount(numNodes);
if (maxScore > currentScore) {
for (int i = 0; i < numNodes; i++) {
float score = bulkScores[i];
if (score > currentScore) {
currentScore = score;
currentEp = bulkNodes[i];
foundBetter = true;
}
}
}
}
@@ -277,25 +287,16 @@ void searchLevel(

int size = getGraphSize(graph);

prepareScratchState(size);

if (bulkNodes == null || bulkNodes.length < graph.maxConn() * 2) {
bulkNodes = new int[graph.maxConn() * 2];
bulkScores = new float[graph.maxConn() * 2];
prepareScratchState(size, graph.maxConn() * 2);
if (bulkScores == null || bulkScores.length < eps.length) {
bulkScores = new float[eps.length];
}

for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
if (results.earlyTerminated()) {
break;
}
float score = scorer.score(ep);
results.incVisitedCount(1);
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
results.collect(ep, score);
}
}
if (results.earlyTerminated()) {
return;
}
scoreEntryPoints(results, scorer, visited, eps, acceptOrds, candidates, bulkScores);
if (results.earlyTerminated()) {
return;
}

// A bound that holds the minimum similarity to the query vector that a candidate vector must
@@ -335,7 +336,7 @@ void searchLevel(
bulkNodes[numNodes++] = friendOrd;
}

numNodes = (int) Math.min((long) numNodes, results.visitLimit() - results.visitedCount());
numNodes = (int) Math.min(numNodes, results.visitLimit() - results.visitedCount());
results.incVisitedCount(numNodes);
if (numNodes > 0
&& scorer.bulkScore(bulkNodes, bulkScores, numNodes)
@@ -365,12 +366,16 @@
}
}

private void prepareScratchState(int capacity) {
private void prepareScratchState(int capacity, int bulkScoreSize) {
candidates.clear();
if (visited.length() < capacity) {
visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
}
visited.clear();
if (bulkNodes == null || bulkNodes.length < bulkScoreSize) {
bulkNodes = new int[bulkScoreSize];
bulkScores = new float[bulkScoreSize];
}
}

/**
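Putting the findBestEntryPoint change together: the greedy descent now batches all unvisited neighbors of the current entry point and issues one bulk call per node expansion. The scratch arrays are sized graph.maxConn() * 2 because HNSW level 0 allows up to twice as many neighbors per node as the upper levels, so a batch of that size can hold any node's full neighbor list. A condensed restatement of one descent step, using the helper and field names from the diff above (it omits the early-termination check the real loop performs while gathering, and defines no new API):

    // Gather first, score once: defer scoring until the neighbor batch is complete.
    graphSeek(graph, level, currentEp);
    int numNodes = 0;
    int friendOrd;
    while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
      if (visited.getAndSet(friendOrd) == false) {
        bulkNodes[numNodes++] = friendOrd;
      }
    }
    // One bulk call scores the whole batch; its return value is the batch maximum.
    float maxScore =
        numNodes > 0
            ? scorer.bulkScore(bulkNodes, bulkScores, numNodes)
            : Float.NEGATIVE_INFINITY;
    collector.incVisitedCount(numNodes);
    if (maxScore > currentScore) { // scan only when the batch can improve the entry point
      for (int i = 0; i < numNodes; i++) {
        if (bulkScores[i] > currentScore) {
          currentScore = bulkScores[i];
          currentEp = bulkNodes[i];
          foundBetter = true;
        }
      }
    }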
------------------------------------------------------------
@@ -1503,10 +1503,13 @@ public void testSearchWithVisitedLimit() throws Exception {
visitedLimit);
assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation());
int size = Lucene99HnswVectorsReader.EXHAUSTIVE_BULK_SCORE_ORDS;
// visit limit is a "best effort" limit given our bulk scoring logic; assert that we are
// within reasonable bounds
[Review comment — Contributor]
nit: clean up spotless multi-line comment

assertTrue(
visitedLimit == results.totalHits.value()
|| ((visitedLimit + size - 1) / size) * ((long) size)
== results.totalHits.value());
results.totalHits.value() == visitedLimit
|| results.totalHits.value()
<= ((visitedLimit + size - 1) / size) * ((long) size));

// check the limit is not hit when it clearly exceeds the number of vectors
k = vectorValues.size();
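The relaxed assertion lets the reported visited count overshoot the limit by at most one bulk batch: it rounds visitedLimit up to a multiple of the batch size. A worked example with an assumed batch size of 64 — the actual value of EXHAUSTIVE_BULK_SCORE_ORDS is not shown in this diff:

    // Hypothetical numbers for illustration only.
    int visitedLimit = 100;
    int size = 64; // assumed batch size
    long bound = ((visitedLimit + size - 1) / size) * ((long) size); // ceil(100/64) * 64 = 128
    // The assertion passes for any totalHits value v with v == visitedLimit or
    // v <= bound (up to 128 here): a bulk batch started just under the limit may
    // finish scoring even after the limit is crossed mid-batch.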