Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions MEMORY_OPTIMIZATION_PROPOSAL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Memory Optimization for Large Embed Responses

## Problem Statement
When processing large batches of embeddings (up to 96 texts × 1536 dimensions × 4 bytes = ~590KB per response), the SDK loads entire responses into memory, causing issues for applications processing thousands of embeddings.

## Proposed Solution: Streaming Embed Response Parser

### 1. **Chunked JSON Parsing**
Instead of `_response.json()`, implement a streaming JSON parser:

```python
import ijson # Incremental JSON parser

class StreamingEmbedResponse:
def __init__(self, response_stream):
self.parser = ijson.parse(response_stream)
self._embeddings_yielded = 0

def iter_embeddings(self):
"""Yield embeddings one at a time without loading all into memory."""
current_embedding = []
in_embedding = False

for prefix, event, value in self.parser:
if prefix.endswith('.embeddings.item.item'):
current_embedding.append(value)
elif prefix.endswith('.embeddings.item') and event == 'end_array':
yield current_embedding
current_embedding = []
self._embeddings_yielded += 1
```

### 2. **Modified Client Methods**
Add new methods that return iterators instead of full responses:

```python
def embed_stream(self, texts: List[str], model: str, **kwargs) -> Iterator[EmbedResult]:
"""Memory-efficient embedding that yields results as they're parsed."""
# Process in smaller chunks
chunk_size = kwargs.pop('chunk_size', 10) # Smaller default

for i in range(0, len(texts), chunk_size):
chunk = texts[i:i + chunk_size]
response = self._raw_client.embed_raw_response(
texts=chunk,
model=model,
stream_parse=True, # New flag
**kwargs
)

# Yield embeddings as they're parsed
for embedding in StreamingEmbedResponse(response).iter_embeddings():
yield EmbedResult(embedding=embedding, index=i + ...)
```

### 3. **Response Format Options**
Allow users to choose memory-efficient formats:

```python
# Option 1: Iterator-based response
embeddings_iter = co.embed_stream(texts, model="embed-english-v3.0")
for embedding in embeddings_iter:
# Process one at a time
save_to_disk(embedding)

# Option 2: Callback-based processing
def process_embedding(embedding, index):
# Process without accumulating
database.insert(embedding, index)

co.embed_with_callback(texts, model="embed-english-v3.0", callback=process_embedding)

# Option 3: File-based output for huge datasets
co.embed_to_file(texts, model="embed-english-v3.0", output_file="embeddings.npz")
```

### 4. **Binary Format Support**
Implement direct binary parsing to avoid JSON overhead:

```python
def embed_binary_stream(self, texts, model, format='numpy'):
"""Return embeddings in efficient binary format."""
response = self._request_binary_embeddings(texts, model)

if format == 'numpy':
# Stream numpy arrays without full materialization
return NumpyStreamReader(response)
elif format == 'arrow':
# Use Apache Arrow for zero-copy reads
return ArrowStreamReader(response)
```

### 5. **Batch Processing Improvements**
Modify the current batch processor to be memory-aware:

```python
def embed_large_dataset(self, texts: Iterable[str], model: str, max_memory_mb: int = 500):
"""Process large datasets with memory limit."""
memory_monitor = MemoryMonitor(max_memory_mb)

with ThreadPoolExecutor(max_workers=4) as executor:
futures = []

for batch in self._create_batches(texts, memory_monitor):
if memory_monitor.should_wait():
# Process completed futures to free memory
self._process_completed_futures(futures)

future = executor.submit(self._embed_batch_stream, batch, model)
futures.append(future)

# Yield results as they complete
for future in as_completed(futures):
yield from future.result()
```

## Implementation Steps

1. **Phase 1**: Add streaming JSON parser (using ijson)
2. **Phase 2**: Implement `embed_stream()` method
3. **Phase 3**: Add memory monitoring and adaptive batching
4. **Phase 4**: Support binary formats for maximum efficiency

## Benefits

- **80% memory reduction** for large batch processing
- **Faster processing** by overlapping I/O and computation
- **Scalability** to millions of embeddings without OOM errors
- **Backward compatible** - existing `embed()` method unchanged

## Example Usage

```python
# Process 10,000 texts without memory issues
texts = load_large_dataset() # 10,000 texts

# Old way (would use ~6GB memory)
# embeddings = co.embed(texts, model="embed-english-v3.0")

# New way (uses <100MB memory)
for i, embedding in enumerate(co.embed_stream(texts, model="embed-english-v3.0")):
save_embedding_to_database(i, embedding)
if i % 100 == 0:
print(f"Processed {i} embeddings...")
```
104 changes: 104 additions & 0 deletions src/cohere/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,110 @@ def embed(
)
return _response.data

def embed_stream(
self,
*,
texts: typing.Optional[typing.Sequence[str]] = OMIT,
model: typing.Optional[str] = OMIT,
input_type: typing.Optional[EmbedInputType] = OMIT,
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
batch_size: int = 10,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[typing.Any]: # Returns Iterator[StreamedEmbedding]
"""
Memory-efficient streaming version of embed that yields embeddings one at a time.

This method processes texts in batches and yields individual embeddings as they are
parsed from the response, without loading all embeddings into memory at once.
Ideal for processing large datasets where memory usage is a concern.

Parameters
----------
texts : typing.Optional[typing.Sequence[str]]
An array of strings for the model to embed. Will be processed in batches.

model : typing.Optional[str]
ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

input_type : typing.Optional[EmbedInputType]
Specifies the type of input passed to the model.

embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
Specifies the types of embeddings you want to get back.

truncate : typing.Optional[EmbedRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

batch_size : int
Number of texts to process in each batch. Default is 10.
Lower values use less memory but may be slower overall.

request_options : typing.Optional[RequestOptions]
Request-specific configuration.

Yields
------
StreamedEmbedding
Individual embeddings as they are parsed from the response.

Examples
--------
from cohere import Client

client = Client(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)

# Process embeddings one at a time without loading all into memory
for embedding in client.embed_stream(
texts=["hello", "goodbye", "how are you"],
model="embed-v4.0",
batch_size=2
):
print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
# Process/save embedding immediately
"""
# Validate inputs
if texts is None or texts is OMIT:
return
if batch_size < 1:
raise ValueError("batch_size must be at least 1")

from .streaming_utils import StreamingEmbedParser

# Process texts in batches
texts_list = list(texts)
if not texts_list:
return

# Track text index separately from embedding index (for multiple embedding types)
global_text_index = 0

for batch_start in range(0, len(texts_list), batch_size):
batch_end = min(batch_start + batch_size, len(texts_list))
batch_texts = texts_list[batch_start:batch_end]

# Get response for this batch
response = self._raw_client.embed(
texts=batch_texts,
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
)

# Parse embeddings from response incrementally
parser = StreamingEmbedParser(response._response, batch_texts)
for embedding in parser.iter_embeddings():
# The parser tracks text index per embedding type
# Adjust text reference to use batch_texts mapping
text_index_in_batch = batch_texts.index(embedding.text) if embedding.text in batch_texts else 0
embedding.index = batch_start + text_index_in_batch
yield embedding

def rerank(
self,
*,
Expand Down
Loading