1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@
__pycache__/
dist/
poetry.toml
.venv/
125 changes: 125 additions & 0 deletions src/cohere/base_client.py
@@ -1125,6 +1125,131 @@ def embed(
)
return _response.data

def embed_stream(
self,
*,
texts: typing.Optional[typing.Sequence[str]] = OMIT,
model: typing.Optional[str] = OMIT,
input_type: typing.Optional[EmbedInputType] = OMIT,
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
batch_size: int = 10,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[typing.Any]: # Returns Iterator[StreamedEmbedding]
"""
Memory-efficient streaming version of embed that yields embeddings one at a time.

This method processes texts in batches and yields individual embeddings as they are
parsed from the response, without loading all embeddings into memory at once.
Ideal for processing large datasets where memory usage is a concern.

Parameters
----------
texts : typing.Optional[typing.Sequence[str]]
An array of strings for the model to embed. Will be processed in batches.

model : typing.Optional[str]
ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

input_type : typing.Optional[EmbedInputType]
Specifies the type of input passed to the model.

embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
Specifies the types of embeddings you want to get back.

truncate : typing.Optional[EmbedRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

batch_size : int
Number of texts to process in each batch. Default is 10.
Lower values use less memory but may be slower overall.

request_options : typing.Optional[RequestOptions]
Request-specific configuration.

Yields
------
StreamedEmbedding
Individual embeddings as they are parsed from the response.

Examples
--------
from cohere import Client

client = Client(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)

# Process embeddings one at a time without loading all into memory
for embedding in client.embed_stream(
texts=["hello", "goodbye", "how are you"],
model="embed-v4.0",
batch_size=2
):
print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
# Process/save embedding immediately
"""
# Validate batch_size
if batch_size < 1:
raise ValueError("batch_size must be at least 1")

# Handle OMIT sentinel and empty texts
if texts is None or texts is OMIT:
return
if not texts:
return

from .streaming_utils import StreamingEmbedParser

# Process texts in batches
texts_list = list(texts)

for batch_start in range(0, len(texts_list), batch_size):
batch_end = min(batch_start + batch_size, len(texts_list))
batch_texts = texts_list[batch_start:batch_end]

# Get response for this batch
response = self._raw_client.embed(
texts=batch_texts,
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
)

# Parse embeddings from response incrementally
parser = StreamingEmbedParser(response._response, batch_texts)
# Track used indices per embedding type to handle:
# 1. Duplicate texts within a batch
# 2. Multiple embedding types (float, int8, etc.) for the same texts
used_batch_indices_by_type: dict[str, set[int]] = {}

for embedding in parser.iter_embeddings():
# The parser sets embedding.text correctly for multiple embedding types
# Adjust the global index based on text position in batch
# Use 'is not None' to handle empty strings correctly (they are falsy but valid)
if embedding.text is not None and embedding.text in batch_texts:
# Get or create the set of used indices for this embedding type
emb_type = embedding.embedding_type
if emb_type not in used_batch_indices_by_type:
used_batch_indices_by_type[emb_type] = set()
used_indices = used_batch_indices_by_type[emb_type]

# Find the next unused occurrence of this text in the batch
# This handles duplicate texts correctly
text_idx_in_batch = None
for idx, text in enumerate(batch_texts):
if text == embedding.text and idx not in used_indices:
text_idx_in_batch = idx
used_indices.add(idx)
break

if text_idx_in_batch is not None:
embedding.index = batch_start + text_idx_in_batch
yield embedding

def rerank(
self,
*,
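
Note: embed_stream imports StreamingEmbedParser from .streaming_utils, which is not included in this diff. The sketch below is a minimal, assumed version of that interface, inferred only from the call sites above; the class, method, and attribute names are assumptions, not something this PR confirms:

    # Assumed interface for src/cohere/streaming_utils.py (not shown in this diff).
    # Names and attributes are inferred from embed_stream's call sites.
    import typing
    from dataclasses import dataclass


    @dataclass
    class StreamedEmbedding:
        index: int                          # global position of the source text in `texts`
        text: typing.Optional[str]          # the text this embedding was computed for
        embedding: typing.Sequence[float]   # embedding values (floats, or ints for int8/uint8)
        embedding_type: str                 # e.g. "float", "int8", "binary"


    class StreamingEmbedParser:
        def __init__(self, raw_response: typing.Any, batch_texts: typing.Sequence[str]) -> None:
            # raw_response is the underlying HTTP response (response._response above)
            self._raw_response = raw_response
            self._batch_texts = list(batch_texts)

        def iter_embeddings(self) -> typing.Iterator[StreamedEmbedding]:
            # Incrementally parse embeddings out of the response body, yielding them
            # one at a time instead of materialising the full embeddings list.
            raise NotImplementedError

Given that interface, duplicate texts and multiple embedding types compose as follows (model and credentials are placeholders):

    from cohere import Client

    client = Client(client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN")

    for emb in client.embed_stream(
        texts=["hello", "hello", "world"],   # duplicate texts are allowed
        model="embed-v4.0",
        embedding_types=["float", "int8"],
        batch_size=2,
    ):
        # Each embedding carries its type plus a global index into `texts`; the
        # per-type index tracking gives the two "hello" occurrences indices 0 and 1
        # for the float embeddings and again 0 and 1 for the int8 embeddings.
        print(emb.embedding_type, emb.index, len(emb.embedding))
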
93 changes: 64 additions & 29 deletions src/cohere/client.py
@@ -1,24 +1,23 @@
import asyncio
import logging
import os
import typing
from concurrent.futures import ThreadPoolExecutor
from tokenizers import Tokenizer # type: ignore
import logging

import httpx

from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse

from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
from . import EmbeddingType, EmbedInputType, EmbedRequestTruncate, EmbedResponse
from .base_client import OMIT, AsyncBaseCohere, BaseCohere
from .config import embed_batch_size
from .core import RequestOptions
from .environment import ClientEnvironment
from .manually_maintained.cache import CacheMixin
from .manually_maintained import tokenizers as local_tokenizers
from .manually_maintained.cache import CacheMixin
from .overrides import run_overrides
from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils
from .utils import AsyncSdkUtils, SyncSdkUtils, async_wait, merge_embed_responses, wait
from tokenizers import Tokenizer # type: ignore

from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse

logger = logging.getLogger(__name__)
run_overrides()
@@ -188,6 +187,8 @@ def embed(
truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
batching: typing.Optional[bool] = True,
batch_size: typing.Optional[int] = None,
max_workers: typing.Optional[int] = None,
) -> EmbedResponse:
# skip batching for images for now
if batching is False or images is not OMIT:
@@ -202,24 +203,39 @@
request_options=request_options,
)

# Validate batch_size
if batch_size is not None and batch_size < 1:
raise ValueError("batch_size must be at least 1")

textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

responses = [
response
for response in self._executor.map(
lambda text_batch: BaseCohere.embed(
self,
texts=text_batch,
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
),
texts_batches,
)
]
effective_batch_size = batch_size if batch_size is not None else embed_batch_size
texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)]

# Use custom executor if max_workers is specified
executor = self._executor
if max_workers is not None:
executor = ThreadPoolExecutor(max_workers=max_workers)

try:
responses = [
response
for response in executor.map(
lambda text_batch: BaseCohere.embed(
self,
texts=text_batch,
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
),
texts_batches,
)
]
finally:
# Clean up custom executor if created
if max_workers is not None:
executor.shutdown(wait=False)

return merge_embed_responses(responses)
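
For reference, the new sync parameters compose like this (model and credentials are placeholders; when batch_size is omitted, the default embed_batch_size from .config is used):

    from cohere import Client

    client = Client(client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN")

    # Four texts with batch_size=2 -> two requests of two texts each.
    # max_workers creates a dedicated ThreadPoolExecutor for this call and shuts it
    # down in the finally block above; omit it to reuse the client's shared executor.
    response = client.embed(
        texts=["doc one", "doc two", "doc three", "doc four"],
        model="embed-v4.0",
        input_type="search_document",
        batch_size=2,
        max_workers=2,
    )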

@@ -380,6 +396,8 @@ async def embed(
truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
batching: typing.Optional[bool] = True,
batch_size: typing.Optional[int] = None,
max_workers: typing.Optional[int] = None,
) -> EmbedResponse:
# skip batching for images for now
if batching is False or images is not OMIT:
@@ -394,9 +412,26 @@
request_options=request_options,
)

textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]
# Validate batch_size
if batch_size is not None and batch_size < 1:
raise ValueError("batch_size must be at least 1")

textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
effective_batch_size = batch_size if batch_size is not None else embed_batch_size
texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)]

        # Note: the max_workers parameter is not applicable to the async client, since
        # asyncio.gather handles concurrency differently than a ThreadPoolExecutor does.
if max_workers is not None:
import warnings
warnings.warn(
"The 'max_workers' parameter is not supported for AsyncClient. "
"Async clients use asyncio.gather() for concurrent execution, which "
"automatically manages concurrency. The parameter will be ignored.",
UserWarning,
stacklevel=2
)

responses = typing.cast(
typing.List[EmbedResponse],
await asyncio.gather(
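
A corresponding async sketch (placeholders as above): batch_size behaves exactly as in the sync client, while max_workers is accepted but ignored with a UserWarning, because concurrency comes from asyncio.gather():

    import asyncio

    from cohere import AsyncClient


    async def main() -> None:
        client = AsyncClient(client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN")
        # Passing max_workers here would only trigger the UserWarning shown above.
        response = await client.embed(
            texts=["doc one", "doc two", "doc three"],
            model="embed-v4.0",
            input_type="search_document",
            batch_size=2,
        )
        print(response)


    asyncio.run(main())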