Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/browser/stores/WorkspaceStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ export class WorkspaceStore {
data: WorkspaceChatMessage
) => void
> = {
"stream-pending": (workspaceId, aggregator, data) => {
aggregator.handleStreamPending(data as never);
if (this.onModelUsed) {
this.onModelUsed((data as { model: string }).model);
}
this.states.bump(workspaceId);
// Bump usage store so liveUsage can show the current model even before streaming starts
this.usageStore.bump(workspaceId);
},
"stream-start": (workspaceId, aggregator, data) => {
aggregator.handleStreamStart(data as never);
if (this.onModelUsed) {
Expand Down Expand Up @@ -484,7 +493,7 @@ export class WorkspaceStore {
name: metadata?.name ?? workspaceId, // Fall back to ID if metadata missing
messages: aggregator.getDisplayedMessages(),
queuedMessage: this.queuedMessages.get(workspaceId) ?? null,
canInterrupt: activeStreams.length > 0,
canInterrupt: activeStreams.length > 0 || aggregator.hasInFlightStreams(),
isCompacting: aggregator.isCompacting(),
awaitingUserQuestion: aggregator.hasAwaitingUserQuestion(),
loading: !hasMessages && !isCaughtUp,
Expand Down Expand Up @@ -969,7 +978,8 @@ export class WorkspaceStore {
// Check if there's an active stream in buffered events (reconnection scenario)
const pendingEvents = this.pendingStreamEvents.get(workspaceId) ?? [];
const hasActiveStream = pendingEvents.some(
(event) => "type" in event && event.type === "stream-start"
(event) =>
"type" in event && (event.type === "stream-start" || event.type === "stream-pending")
);

// Load historical messages first
Expand Down
52 changes: 52 additions & 0 deletions src/browser/utils/messages/StreamingMessageAggregator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,58 @@ describe("StreamingMessageAggregator", () => {
expect(aggregator.getCurrentTodos()).toHaveLength(0);
});

test("should clear in-flight streams when pending stream aborts", () => {
const aggregator = new StreamingMessageAggregator(TEST_CREATED_AT);

aggregator.handleStreamPending({
type: "stream-pending",
workspaceId: "test-workspace",
messageId: "msg1",
historySequence: 1,
model: "claude-3-5-sonnet-20241022",
});

expect(aggregator.hasInFlightStreams()).toBe(true);

aggregator.handleStreamAbort({
type: "stream-abort",
workspaceId: "test-workspace",
messageId: "msg1",
metadata: {},
});

expect(aggregator.hasInFlightStreams()).toBe(false);
});

test("should surface stream-error when tracked stream errors before stream-start", () => {
const aggregator = new StreamingMessageAggregator(TEST_CREATED_AT);

aggregator.handleStreamPending({
type: "stream-pending",
workspaceId: "test-workspace",
messageId: "msg1",
historySequence: 1,
model: "claude-3-5-sonnet-20241022",
});

aggregator.handleStreamError({
type: "stream-error",
messageId: "msg1",
error: "Boom",
errorType: "unknown",
});

expect(aggregator.hasInFlightStreams()).toBe(false);

const displayed = aggregator.getDisplayedMessages();
const errorMsg = displayed.find((m) => m.type === "stream-error");
expect(errorMsg).toBeDefined();
if (errorMsg?.type === "stream-error") {
expect(errorMsg.error).toBe("Boom");
expect(errorMsg.errorType).toBe("unknown");
}
});

test("should reconstruct todos on reload ONLY when reconnecting to active stream", () => {
const aggregator = new StreamingMessageAggregator(TEST_CREATED_AT);

Expand Down
144 changes: 100 additions & 44 deletions src/browser/utils/messages/StreamingMessageAggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type {
} from "@/common/types/message";
import { createMuxMessage } from "@/common/types/message";
import type {
StreamPendingEvent,
StreamStartEvent,
StreamDeltaEvent,
UsageDeltaEvent,
Expand Down Expand Up @@ -52,6 +53,17 @@ interface StreamingContext {
model: string;
}

type InFlightStreamState =
| {
phase: "pending";
pendingAt: number;
model: string;
}
| {
phase: "active";
context: StreamingContext;
};

/**
* Check if a tool result indicates success (for tools that return { success: boolean })
*/
Expand Down Expand Up @@ -136,7 +148,9 @@ function mergeAdjacentParts(parts: MuxMessage["parts"]): MuxMessage["parts"] {

export class StreamingMessageAggregator {
private messages = new Map<string, MuxMessage>();
private activeStreams = new Map<string, StreamingContext>();

// Streams that are in-flight (pending: `stream-pending` received; active: `stream-start` received).
private inFlightStreams = new Map<string, InFlightStreamState>();

// Simple cache for derived values (invalidated on every mutation)
private cachedAllMessages: MuxMessage[] | null = null;
Expand Down Expand Up @@ -336,14 +350,14 @@ export class StreamingMessageAggregator {
* Called by handleStreamEnd, handleStreamAbort, and handleStreamError.
*
* Clears:
* - Active stream tracking (this.activeStreams)
* - In-flight stream tracking (this.inFlightStreams)
* - Current TODOs (this.currentTodos) - reconstructed from history on reload
*
* Does NOT clear:
* - agentStatus - persists after stream completion to show last activity
*/
private cleanupStreamState(messageId: string): void {
this.activeStreams.delete(messageId);
this.inFlightStreams.delete(messageId);
// Clear todos when stream ends - they're stream-scoped state
// On reload, todos will be reconstructed from completed tool_write calls in history
this.currentTodos = [];
Expand Down Expand Up @@ -461,21 +475,31 @@ export class StreamingMessageAggregator {
this.pendingStreamStartTime = time;
}

hasInFlightStreams(): boolean {
return this.inFlightStreams.size > 0;
}
getActiveStreams(): StreamingContext[] {
return Array.from(this.activeStreams.values());
const active: StreamingContext[] = [];
for (const stream of this.inFlightStreams.values()) {
if (stream.phase === "active") active.push(stream.context);
}
return active;
}

/**
* Get the messageId of the first active stream (for token tracking)
* Returns undefined if no streams are active
*/
getActiveStreamMessageId(): string | undefined {
return this.activeStreams.keys().next().value;
for (const [messageId, stream] of this.inFlightStreams.entries()) {
if (stream.phase === "active") return messageId;
}
return undefined;
}

isCompacting(): boolean {
for (const context of this.activeStreams.values()) {
if (context.isCompacting) {
for (const stream of this.inFlightStreams.values()) {
if (stream.phase === "active" && stream.context.isCompacting) {
return true;
}
}
Expand All @@ -484,8 +508,13 @@ export class StreamingMessageAggregator {

getCurrentModel(): string | undefined {
// If there's an active stream, return its model
for (const context of this.activeStreams.values()) {
return context.model;
for (const stream of this.inFlightStreams.values()) {
if (stream.phase === "active") return stream.context.model;
}

// If we're pending (stream-pending), return that model
for (const stream of this.inFlightStreams.values()) {
if (stream.phase === "pending") return stream.model;
}

// Otherwise, return the model from the most recent assistant message
Expand All @@ -501,12 +530,14 @@ export class StreamingMessageAggregator {
}

clearActiveStreams(): void {
this.activeStreams.clear();
this.setPendingStreamStartTime(null);
this.inFlightStreams.clear();
this.invalidateCache();
}

clear(): void {
this.messages.clear();
this.activeStreams.clear();
this.inFlightStreams.clear();
this.invalidateCache();
}

Expand All @@ -529,8 +560,24 @@ export class StreamingMessageAggregator {
}

// Unified event handlers that encapsulate all complex logic
handleStreamPending(data: StreamPendingEvent): void {
// Clear pending stream start timestamp - backend has accepted the request.
this.setPendingStreamStartTime(null);

const existing = this.inFlightStreams.get(data.messageId);
if (existing?.phase === "active") return;

this.inFlightStreams.set(data.messageId, {
phase: "pending",
pendingAt: Date.now(),
model: data.model,
});

this.invalidateCache();
}

handleStreamStart(data: StreamStartEvent): void {
// Clear pending stream start timestamp - stream has started
// Clear pending stream start timestamp - stream has started.
this.setPendingStreamStartTime(null);

// NOTE: We do NOT clear agentStatus or currentTodos here.
Expand All @@ -551,7 +598,7 @@ export class StreamingMessageAggregator {

// Use messageId as key - ensures only ONE stream per message
// If called twice (e.g., during replay), second call safely overwrites first
this.activeStreams.set(data.messageId, context);
this.inFlightStreams.set(data.messageId, { phase: "active", context });

// Create initial streaming message with empty parts (deltas will append)
const streamingMessage = createMuxMessage(data.messageId, "assistant", "", {
Expand Down Expand Up @@ -583,7 +630,8 @@ export class StreamingMessageAggregator {

handleStreamEnd(data: StreamEndEvent): void {
// Direct lookup by messageId - O(1) instead of O(n) find
const activeStream = this.activeStreams.get(data.messageId);
const stream = this.inFlightStreams.get(data.messageId);
const activeStream = stream?.phase === "active" ? stream.context : undefined;

if (activeStream) {
// Normal streaming case: we've been tracking this stream from the start
Expand Down Expand Up @@ -650,9 +698,10 @@ export class StreamingMessageAggregator {

handleStreamAbort(data: StreamAbortEvent): void {
// Direct lookup by messageId
const activeStream = this.activeStreams.get(data.messageId);
const stream = this.inFlightStreams.get(data.messageId);
if (!stream) return;

if (activeStream) {
if (stream.phase === "active") {
// Mark the message as interrupted and merge metadata (consistent with handleStreamEnd)
const message = this.messages.get(data.messageId);
if (message?.metadata) {
Expand All @@ -665,36 +714,16 @@ export class StreamingMessageAggregator {
// Compact parts even on abort - still reduces memory for partial messages
this.compactMessageParts(message);
}

// Clean up stream-scoped state (active stream tracking, TODOs)
this.cleanupStreamState(data.messageId);
this.invalidateCache();
}

// Always clean up stream-scoped state (pending or active) to avoid wedging canInterrupt=true.
this.cleanupStreamState(data.messageId);
this.invalidateCache();
}

handleStreamError(data: StreamErrorMessage): void {
// Direct lookup by messageId
const activeStream = this.activeStreams.get(data.messageId);

if (activeStream) {
// Mark the message with error metadata
const message = this.messages.get(data.messageId);
if (message?.metadata) {
message.metadata.partial = true;
message.metadata.error = data.error;
message.metadata.errorType = data.errorType;

// Compact parts even on error - still reduces memory for partial messages
this.compactMessageParts(message);
}

// Clean up stream-scoped state (active stream tracking, TODOs)
this.cleanupStreamState(data.messageId);
this.invalidateCache();
} else {
// Pre-stream error (e.g., API key not configured before streaming starts)
// Create a synthetic error message since there's no active stream to attach to
// Get the highest historySequence from existing messages so this appears at the end
const createSyntheticErrorMessage = (): void => {
// Get the highest historySequence from existing messages so this appears at the end.
const maxSequence = Math.max(
0,
...Array.from(this.messages.values()).map((m) => m.metadata?.historySequence ?? 0)
Expand All @@ -712,8 +741,35 @@ export class StreamingMessageAggregator {
},
};
this.messages.set(data.messageId, errorMessage);
};

const isTrackedStream = this.inFlightStreams.has(data.messageId);

if (isTrackedStream) {
// Mark the message with error metadata
const message = this.messages.get(data.messageId);
if (message?.metadata) {
message.metadata.partial = true;
message.metadata.error = data.error;
message.metadata.errorType = data.errorType;

// Compact parts even on error - still reduces memory for partial messages
this.compactMessageParts(message);
} else {
// Stream errored before stream-start created a message (pending-phase).
createSyntheticErrorMessage();
}

// Clean up stream-scoped state (active/connecting tracking, TODOs)
this.cleanupStreamState(data.messageId);
this.invalidateCache();
return;
}

// Pre-stream error (e.g., API key not configured before streaming starts)
// Create a synthetic error message since there's no tracked stream to attach to.
createSyntheticErrorMessage();
this.invalidateCache();
}

handleToolCallStart(data: ToolCallStartEvent): void {
Expand Down Expand Up @@ -844,7 +900,7 @@ export class StreamingMessageAggregator {

handleReasoningEnd(_data: ReasoningEndEvent): void {
// Reasoning-end is just a signal - no state to update
// Streaming status is inferred from activeStreams in getDisplayedMessages
// Streaming status is inferred from inFlightStreams in getDisplayedMessages
this.invalidateCache();
}

Expand Down Expand Up @@ -1035,7 +1091,7 @@ export class StreamingMessageAggregator {

// Check if this message has an active stream (for inferring streaming status)
// Direct Map.has() check - O(1) instead of O(n) iteration
const hasActiveStream = this.activeStreams.has(message.id);
const hasActiveStream = this.inFlightStreams.get(message.id)?.phase === "active";

// Merge adjacent text/reasoning parts for display
const mergedParts = mergeAdjacentParts(message.parts);
Expand Down
Loading