Commit 6e23658

VoyageAI refactoring:
- contextual model
- removing the model default value
- token counting, i.e. more effective use of batches
1 parent 63dc6b2 commit 6e23658
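For orientation, a minimal caller-side sketch of what this refactor changes (illustrative only, not code from the commit; it assumes the usual `redisvl.utils.vectorize` import path and the `api_key` key in `api_config`):

```python
from redisvl.utils.vectorize import VoyageAITextVectorizer

# The model argument no longer has a default, so it must be passed explicitly;
# the API key comes from api_config or the VOYAGE_API_KEY environment variable.
vectorizer = VoyageAITextVectorizer(
    model="voyage-3.5",
    api_config={"api_key": "<VOYAGE_API_KEY>"},  # hypothetical key name for illustration
)

# Batching is now token-aware and automatic; an explicit batch_size is ignored.
embeddings = vectorizer.embed_many(["first document", "second document"])
```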

File tree

3 files changed (+72 −214 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -37,7 +37,7 @@ mistralai = ["mistralai>=1.0.0"]
 openai = ["openai>=1.1.0"]
 nltk = ["nltk>=3.8.1,<4"]
 cohere = ["cohere>=4.44"]
-voyageai = ["voyageai>=0.2.2"]
+voyageai = ["voyageai>=0.3.5"]
 sentence-transformers = ["sentence-transformers>=3.4.0,<4"]
 vertexai = [
     "google-cloud-aiplatform>=1.26,<2.0.0",
```

redisvl/utils/vectorize/text/voyageai.py

Lines changed: 18 additions & 156 deletions
```diff
@@ -111,15 +111,15 @@ class VoyageAITextVectorizer(BaseVectorizer):
         # Token counting for API usage management
         token_counts = vectorizer.count_tokens(["text one", "text two"])
         print(f"Token counts: {token_counts}")
-        print(f"Model token limit: {vectorizer.get_token_limit()}")
+        print(f"Model token limit: {VOYAGE_TOTAL_TOKEN_LIMITS.get(vectorizer.model, 120_000)}")
 
     """
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     def __init__(
         self,
-        model: str = "voyage-3.5",
+        model: str,
         api_config: Optional[Dict] = None,
         dtype: str = "float32",
         cache: Optional["EmbeddingsCache"] = None,
@@ -130,7 +130,7 @@ def __init__(
         Visit https://docs.voyageai.com/docs/embeddings to learn about embeddings and check the available models.
 
         Args:
-            model (str): Model to use for embedding. Defaults to "voyage-3.5".
+            model (str): Model to use for embedding (e.g., "voyage-3.5", "voyage-context-3").
             api_config (Optional[Dict], optional): Dictionary containing the API key.
                 Defaults to None.
             dtype (str): the default datatype to use when embedding text as byte arrays.
```
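Because the default was removed from `__init__`, constructing the vectorizer without a model now fails up front. A small hypothetical check (illustrative only; the contextualized-model call assumes VOYAGE_API_KEY is set):

```python
from redisvl.utils.vectorize import VoyageAITextVectorizer

# With `model: str` no longer defaulted, omitting it fails immediately
# instead of silently falling back to "voyage-3.5".
try:
    VoyageAITextVectorizer()  # type: ignore[call-arg]
except TypeError as exc:
    print(f"model is now required: {exc}")

# Contextualized models (see _is_context_model further down) are passed the same way.
ctx_vectorizer = VoyageAITextVectorizer(model="voyage-context-3")
```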
```diff
@@ -213,22 +213,6 @@ def _set_model_dims(self) -> int:
             # fall back (TODO get more specific)
             raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
 
-    def _get_batch_size(self) -> int:
-        """
-        Determine the appropriate batch size based on the model being used.
-
-        Returns:
-            int: Recommended batch size for the current model
-        """
-        if self.model in ["voyage-2", "voyage-02"]:
-            return 72
-        elif self.model in ["voyage-3-lite", "voyage-3.5-lite"]:
-            return 30
-        elif self.model in ["voyage-3", "voyage-3.5"]:
-            return 10
-        else:
-            return 7  # Default for other models
-
     def _validate_input(
         self, texts: List[str], input_type: Optional[str], truncation: Optional[bool]
     ):
```
```diff
@@ -285,16 +269,12 @@ def _embed_many(
         """
         Generate vector embeddings for a batch of texts using the VoyageAI API.
 
+        Uses token-aware batching to respect model token limits and optimize API calls.
+
         Args:
             texts: List of texts to embed
-            batch_size: Number of texts to process in each API call.
-                Ignored if use_token_batching=True.
+            batch_size: Deprecated. Token-aware batching is now always used.
             **kwargs: Additional parameters to pass to the VoyageAI API.
-                Special kwargs:
-                - use_token_batching (bool): If True, use token-aware batching
-                  instead of simple batch_size-based batching. This respects
-                  model token limits and is recommended for large documents.
-                  Default: False.
 
         Returns:
             List[List[float]]: List of vector embeddings as lists of floats
@@ -305,20 +285,12 @@ def _embed_many(
         """
         input_type = kwargs.pop("input_type", None)
         truncation = kwargs.pop("truncation", None)
-        use_token_batching = kwargs.pop("use_token_batching", False)
 
         # Validate inputs
         self._validate_input(texts, input_type, truncation)
 
-        # Determine batching strategy
-        if use_token_batching:
-            # Use token-aware batching
-            batches = self._build_token_aware_batches(texts, max_batch_size=1000)
-        else:
-            # Use simple batch_size-based batching
-            if batch_size is None:
-                batch_size = self._get_batch_size()
-            batches = list(self.batchify(texts, batch_size))
+        # Use token-aware batching
+        batches = self._build_token_aware_batches(texts)
 
         try:
             embeddings: List = []
@@ -342,10 +314,10 @@ def _embed_many(
                     texts=batch,
                     model=self.model,
                     input_type=input_type,
-                    truncation=truncation,
+                    truncation=truncation,  # type: ignore[assignment]
                     **kwargs,
                 )
-                embeddings.extend(response.embeddings)
+                embeddings.extend(response.embeddings)  # type: ignore[attr-defined]
             return embeddings
         except Exception as e:
             raise ValueError(f"Embedding texts failed: {e}")
```
```diff
@@ -380,16 +352,12 @@ async def _aembed_many(
         """
         Asynchronously generate vector embeddings for a batch of texts using the VoyageAI API.
 
+        Uses token-aware batching to respect model token limits and optimize API calls.
+
         Args:
             texts: List of texts to embed
-            batch_size: Number of texts to process in each API call.
-                Ignored if use_token_batching=True.
+            batch_size: Deprecated. Token-aware batching is now always used.
             **kwargs: Additional parameters to pass to the VoyageAI API.
-                Special kwargs:
-                - use_token_batching (bool): If True, use token-aware batching
-                  instead of simple batch_size-based batching. This respects
-                  model token limits and is recommended for large documents.
-                  Default: False.
 
         Returns:
             List[List[float]]: List of vector embeddings as lists of floats
@@ -400,20 +368,12 @@ async def _aembed_many(
         """
         input_type = kwargs.pop("input_type", None)
         truncation = kwargs.pop("truncation", None)
-        use_token_batching = kwargs.pop("use_token_batching", False)
 
         # Validate inputs
         self._validate_input(texts, input_type, truncation)
 
-        # Determine batching strategy
-        if use_token_batching:
-            # Use token-aware batching
-            batches = await self._abuild_token_aware_batches(texts, max_batch_size=1000)
-        else:
-            # Use simple batch_size-based batching
-            if batch_size is None:
-                batch_size = self._get_batch_size()
-            batches = list(self.batchify(texts, batch_size))
+        # Use token-aware batching (synchronous - tokenization is sync-only)
+        batches = self._build_token_aware_batches(texts)
 
         try:
             embeddings: List = []
@@ -437,10 +397,10 @@ async def _aembed_many(
                     texts=batch,
                     model=self.model,
                     input_type=input_type,
-                    truncation=truncation,
+                    truncation=truncation,  # type: ignore[assignment]
                     **kwargs,
                 )
-                embeddings.extend(response.embeddings)
+                embeddings.extend(response.embeddings)  # type: ignore[attr-defined]
             return embeddings
         except Exception as e:
             raise ValueError(f"Embedding texts failed: {e}")
```
```diff
@@ -473,48 +433,6 @@ def count_tokens(self, texts: List[str]) -> List[int]:
         except Exception as e:
             raise ValueError(f"Token counting failed: {e}")
 
-    async def acount_tokens(self, texts: List[str]) -> List[int]:
-        """
-        Asynchronously count tokens for the given texts using VoyageAI's tokenization API.
-
-        Args:
-            texts: List of texts to count tokens for.
-
-        Returns:
-            List[int]: List of token counts for each text.
-
-        Raises:
-            ValueError: If tokenization fails.
-
-        Example:
-            >>> vectorizer = VoyageAITextVectorizer(model="voyage-3.5")
-            >>> token_counts = await vectorizer.acount_tokens(["Hello world", "Another text"])
-            >>> print(token_counts)  # [2, 2]
-        """
-        if not texts:
-            return []
-
-        try:
-            # Use the VoyageAI async tokenize API to get token counts
-            token_lists = await self._aclient.tokenize(texts, model=self.model)
-            return [len(token_list) for token_list in token_lists]
-        except Exception as e:
-            raise ValueError(f"Token counting failed: {e}")
-
-    def get_token_limit(self) -> int:
-        """
-        Get the total token limit for the current model.
-
-        Returns:
-            int: Token limit for the model, or default of 120_000 if not found.
-
-        Example:
-            >>> vectorizer = VoyageAITextVectorizer(model="voyage-context-3")
-            >>> limit = vectorizer.get_token_limit()
-            >>> print(limit)  # 32000
-        """
-        return VOYAGE_TOTAL_TOKEN_LIMITS.get(self.model, 120_000)
-
     def _is_context_model(self) -> bool:
         """
         Check if the current model is a contextualized embedding model.
```
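With acount_tokens and get_token_limit removed, the surviving sync count_tokens plus the module-level limits table cover both jobs. A hedged sketch, assuming VOYAGE_TOTAL_TOKEN_LIMITS is importable from the module as the class docstring example implies:

```python
from redisvl.utils.vectorize import VoyageAITextVectorizer
from redisvl.utils.vectorize.text.voyageai import VOYAGE_TOTAL_TOKEN_LIMITS

vectorizer = VoyageAITextVectorizer(model="voyage-3.5")  # assumes VOYAGE_API_KEY is set

# count_tokens (sync) survives the refactor; the removed get_token_limit() helper
# is replaced by a direct lookup with the same 120_000-token fallback.
token_counts = vectorizer.count_tokens(["Hello world", "Another text"])
token_limit = VOYAGE_TOTAL_TOKEN_LIMITS.get(vectorizer.model, 120_000)
print(token_counts, token_limit)
```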
```diff
@@ -550,7 +468,7 @@ def _build_token_aware_batches(
         if not texts:
             return []
 
-        max_tokens_per_batch = self.get_token_limit()
+        max_tokens_per_batch = VOYAGE_TOTAL_TOKEN_LIMITS.get(self.model, 120_000)
         batches = []
         current_batch: List[str] = []
         current_batch_tokens = 0
@@ -583,62 +501,6 @@ def _build_token_aware_batches(
 
         return batches
 
-    async def _abuild_token_aware_batches(
-        self, texts: List[str], max_batch_size: int = 1000
-    ) -> List[List[str]]:
-        """
-        Asynchronously generate batches of texts based on token limits and batch size constraints.
-
-        This method uses VoyageAI's tokenization API to count tokens for all texts
-        in a single call, then creates batches that respect both the model's token
-        limit and a maximum batch size.
-
-        Args:
-            texts: List of texts to batch.
-            max_batch_size: Maximum number of texts per batch (default: 1000).
-
-        Returns:
-            List[List[str]]: List of batches, where each batch is a list of texts.
-
-        Raises:
-            ValueError: If tokenization fails.
-        """
-        if not texts:
-            return []
-
-        max_tokens_per_batch = self.get_token_limit()
-        batches = []
-        current_batch: List[str] = []
-        current_batch_tokens = 0
-
-        # Tokenize all texts in one API call for efficiency
-        try:
-            token_counts = await self.acount_tokens(texts)
-        except Exception as e:
-            raise ValueError(f"Failed to count tokens for batching: {e}")
-
-        for i, text in enumerate(texts):
-            n_tokens = token_counts[i]
-
-            # Check if adding this text would exceed limits
-            if current_batch and (
-                len(current_batch) >= max_batch_size
-                or (current_batch_tokens + n_tokens > max_tokens_per_batch)
-            ):
-                # Save the current batch and start a new one
-                batches.append(current_batch)
-                current_batch = []
-                current_batch_tokens = 0
-
-            current_batch.append(text)
-            current_batch_tokens += n_tokens
-
-        # Add the last batch if it has any texts
-        if current_batch:
-            batches.append(current_batch)
-
-        return batches
-
     @property
     def type(self) -> str:
         return "voyageai"
```

0 commit comments
