Skip to content

Commit 04461cd

Browse files
committed
Add token counting and context model detection to VoyageAI
1 parent ff415fb commit 04461cd

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed

redisvl/utils/vectorize/voyageai.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,26 @@
1313
# ignore that voyageai isn't imported
1414
# mypy: disable-error-code="name-defined"
1515

16+
# Total per-request token limits for VoyageAI models, used for token-aware
# batching. Callers look models up with .get(model, <default>) so unknown
# model names fall back to a caller-supplied conservative default.
VOYAGE_TOTAL_TOKEN_LIMITS = {
    "voyage-context-3": 32_000,
    "voyage-3.5-lite": 1_000_000,
    "voyage-3.5": 320_000,
    "voyage-2": 320_000,
    "voyage-3-large": 120_000,
    "voyage-code-3": 120_000,
    "voyage-large-2-instruct": 120_000,
    "voyage-finance-2": 120_000,
    "voyage-multilingual-2": 120_000,
    "voyage-law-2": 120_000,
    "voyage-large-2": 120_000,
    "voyage-3": 120_000,
    "voyage-3-lite": 120_000,
    "voyage-code-2": 120_000,
    "voyage-multimodal-3": 32_000,
    "voyage-multimodal-3.5": 32_000,
}
35+
1636

1737
class VoyageAIVectorizer(BaseVectorizer):
1838
"""The VoyageAIVectorizer class utilizes VoyageAI's API to generate
@@ -87,6 +107,21 @@ class VoyageAIVectorizer(BaseVectorizer):
87107
input_type="query"
88108
)
89109
110+
# Using contextualized embeddings (voyage-context-3)
111+
context_vectorizer = VoyageAIVectorizer(
112+
model="voyage-context-3",
113+
api_config={"api_key": "your-voyageai-api-key"}
114+
)
115+
# Context models automatically use contextualized_embed API
116+
context_embeddings = context_vectorizer.embed_many(
117+
contents=["chunk 1", "chunk 2", "chunk 3"],
118+
input_type="document"
119+
)
120+
121+
# Token counting for API usage management
122+
token_counts = vectorizer.count_tokens(["text one", "text two"])
123+
print(f"Token counts: {token_counts}")
124+
90125
"""
91126

92127
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -448,6 +483,80 @@ def _serialize_for_cache(self, content: Any) -> Union[bytes, str]:
448483
return content.to_bytes()
449484
return super()._serialize_for_cache(content)
450485

486+
def _is_context_model(self) -> bool:
487+
"""
488+
Check if the current model is a contextualized embedding model.
489+
490+
Contextualized models (like voyage-context-3) use a different API
491+
endpoint and expect inputs formatted differently.
492+
493+
Returns:
494+
bool: True if the model is a context model, False otherwise.
495+
"""
496+
return "context" in self.model
497+
498+
def count_tokens(self, texts: List[str]) -> List[int]:
    """
    Count tokens for the given texts using VoyageAI's tokenization API.

    This is useful for managing API usage and optimizing batching strategies.

    Args:
        texts: List of texts to count tokens for.

    Returns:
        List[int]: List of token counts for each text, in input order.

    Raises:
        ValueError: If tokenization fails; the original exception is
            attached as ``__cause__``.

    Example:
        >>> vectorizer = VoyageAIVectorizer(model="voyage-3.5")
        >>> token_counts = vectorizer.count_tokens(["Hello world", "Another text"])
        >>> print(token_counts)  # [2, 2]
    """
    # Empty input short-circuits without hitting the API.
    if not texts:
        return []

    try:
        token_lists = self._client.tokenize(texts, model=self.model)
        return [len(token_list) for token_list in token_lists]
    except Exception as e:
        # Chain explicitly so callers can inspect the root cause via
        # err.__cause__ instead of relying on implicit context.
        raise ValueError(f"Token counting failed: {e}") from e
526+
527+
async def acount_tokens(self, texts: List[str]) -> List[int]:
    """
    Asynchronously count tokens for the given texts using VoyageAI's tokenization API.

    This is useful for managing API usage and optimizing batching strategies.

    Note: The underlying VoyageAI tokenize API is synchronous, so this method
    provides async compatibility but doesn't offer true async performance benefits.

    Args:
        texts: List of texts to count tokens for.

    Returns:
        List[int]: List of token counts for each text, in input order.

    Raises:
        ValueError: If tokenization fails; the original exception is
            attached as ``__cause__``.

    Example:
        >>> vectorizer = VoyageAIVectorizer(model="voyage-3.5")
        >>> token_counts = await vectorizer.acount_tokens(["Hello world", "Another text"])
        >>> print(token_counts)  # [2, 2]
    """
    # Empty input short-circuits without hitting the API.
    if not texts:
        return []

    try:
        # Note: VoyageAI's tokenize is synchronous even on AsyncClient,
        # so it is deliberately not awaited here.
        token_lists = self._aclient.tokenize(texts, model=self.model)
        return [len(token_list) for token_list in token_lists]
    except Exception as e:
        # Chain explicitly so callers can inspect the root cause via
        # err.__cause__ instead of relying on implicit context.
        raise ValueError(f"Token counting failed: {e}") from e
559+
451560
@property
def type(self) -> str:
    """Provider identifier for this vectorizer ("voyageai")."""
    return "voyageai"

tests/integration/test_vectorizers.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,3 +629,70 @@ def test_deprecated_text_parameter_warning():
629629
embeddings = vectorizer.embed_many(texts=TEST_TEXTS)
630630
assert isinstance(embeddings, list)
631631
assert len(embeddings) == len(TEST_TEXTS)
632+
633+
634+
# VoyageAI-specific tests for token counting and context model detection
@pytest.mark.requires_api_keys
def test_voyageai_count_tokens():
    """VoyageAI token counting returns one positive int per input text."""
    vectorizer = VoyageAIVectorizer(model="voyage-3.5")
    sample_texts = ["Hello world", "This is a longer test sentence."]

    counts = vectorizer.count_tokens(sample_texts)
    assert isinstance(counts, list)
    assert len(counts) == len(sample_texts)
    for count in counts:
        assert isinstance(count, int)
        assert count > 0

    # An empty input list yields an empty result
    assert vectorizer.count_tokens([]) == []
648+
649+
650+
@pytest.mark.requires_api_keys
@pytest.mark.asyncio
async def test_voyageai_acount_tokens():
    """Async VoyageAI token counting mirrors the sync behavior."""
    vectorizer = VoyageAIVectorizer(model="voyage-3.5")
    sample_texts = ["Hello world", "This is a longer test sentence."]

    counts = await vectorizer.acount_tokens(sample_texts)
    assert isinstance(counts, list)
    assert len(counts) == len(sample_texts)
    for count in counts:
        assert isinstance(count, int)
        assert count > 0

    # An empty input list yields an empty result
    assert await vectorizer.acount_tokens([]) == []
664+
665+
666+
def test_voyageai_token_limits():
    """Token limit constants expose the expected per-model budgets."""
    from redisvl.utils.vectorize.voyageai import VOYAGE_TOTAL_TOKEN_LIMITS

    expected_limits = {
        "voyage-context-3": 32_000,
        "voyage-3.5-lite": 1_000_000,
        "voyage-3.5": 320_000,
        "voyage-multimodal-3": 32_000,
        "voyage-multimodal-3.5": 32_000,
    }
    for model_name, limit in expected_limits.items():
        assert VOYAGE_TOTAL_TOKEN_LIMITS.get(model_name) == limit

    # Unknown models fall back to the caller-supplied default
    assert VOYAGE_TOTAL_TOKEN_LIMITS.get("unknown-model", 120_000) == 120_000
679+
680+
681+
def test_voyageai_context_model_detection():
    """Detection of contextualized embedding models from the model name.

    The _is_context_model method simply checks: "context" in self.model,
    so the same substring check is exercised here for known models.
    """
    expected_by_model = {
        "voyage-3.5": False,
        "voyage-context-3": True,
        "voyage-multimodal-3.5": False,
        "voyage-3-large": False,
    }
    for model_name, expected in expected_by_model.items():
        assert ("context" in model_name) == expected, f"Failed for {model_name}"

0 commit comments

Comments
 (0)