Compare commits

...

14 Commits

Author SHA1 Message Date
Zamil Majdy
bac7b9efb9 fix(copilot): update shared counter after collision detection
When collision detection in add_chat_messages_batch retries with a higher
sequence number, the actual persisted message count may differ from
len(session.messages). This commit ensures the shared counter
(saved_msg_count_ref) used by the streaming loop and long-running callback
stays synchronized with the actual DB state.

Changes:
- Modified add_chat_messages_batch to return tuple[list[ChatMessage], int]
  where the int is the final message count after collision resolution
- Updated _save_session_to_db and upsert_chat_session to propagate the
  final count up the call chain
- Updated all callers in sdk/service.py to use the returned count instead
  of len(session.messages) when updating saved_msg_count_ref
- Updated all other callers in service.py and tests to handle tuple return
2026-02-20 18:58:02 +07:00
Zamil Majdy
6e1941d7ae feat(copilot): implement session locking to prevent concurrent streams
- Add stream_id (using task_id) to uniquely identify each stream
- Acquire exclusive lock (Redis SET NX EX) when starting a stream
- Release lock in finally block using Lua script (atomic compare-and-delete)
- Return error if another stream is already active for the session
- Lock TTL is 1 hour (matches stream_ttl) with automatic cleanup

This prevents:
- Message duplication from concurrent streams
- Race conditions in message saves
- Confusing UX with multiple AI responses
- Frontend reconnecting while existing stream is active
- Multiple browser tabs streaming to same session
2026-02-20 18:28:35 +07:00
Zamil Majdy
129b992059 feat(copilot): increase long-running operation TTL to 1 hour
- Increase long_running_operation_ttl from 600s (10min) to 3600s (1hour)
- Match stream_ttl duration for consistency
- Add clarifying description about deduplication lock purpose

Some operations (like complex agent runs) can take longer than 10 minutes.
The stream_registry heartbeat (publish_chunk) already keeps operations alive,
so this TTL is just a safety net for deduplication.
2026-02-20 18:22:26 +07:00
Zamil Majdy
1b82a55eca chore: remove obsolete plan file
Plan was completed and changes are now in the PR. No need to keep the plan file.
2026-02-20 18:21:00 +07:00
Zamil Majdy
9d4697e859 refactor(copilot): replace COUNT with MAX for sequence tracking
- Rename get_max_sequence() to get_next_sequence() returning MAX+1
- Replace all get_chat_session_message_count() calls with get_next_sequence()
- Remove old get_chat_session_message_count() function
- Update db_manager.py to export get_next_sequence

Using MAX(sequence)+1 is more robust than COUNT(*) because:
- Immune to deleted messages
- Handles gaps in sequence numbers correctly
- Simpler collision detection logic
2026-02-20 18:20:29 +07:00
Zamil Majdy
366547e448 refactor(copilot): remove confusing 'Layer' comments from code
- Remove '(Layer 3: defense-in-depth)' annotations
- Replace with clearer explanations of what the code does
- Makes the code easier to understand without implementation history
2026-02-20 18:18:25 +07:00
Zamil Majdy
af491b5511 refactor(copilot): replace upsert with collision detection for concurrent message saves
- Use create() with MAX(sequence) retry instead of upsert()
- Query DB only on collision (not every save) for better performance
- Remove Layer 2 DB queries from incremental saves in streaming loop
- Add get_max_sequence() helper using raw SQL for robustness
- Collision detection retries up to 3 times on unique constraint errors

This approach:
- Optimizes common case (no collision) - no extra DB queries
- Handles concurrent writes via automatic retry with correct sequence
- Uses MAX(sequence) instead of COUNT for more robust offset calculation
2026-02-20 18:09:34 +07:00
Zamil Majdy
6acefee6f3 fix(copilot): defense-in-depth for concurrent message saves (all 3 layers)
Implements three complementary layers to prevent unique constraint violations
on (sessionId, sequence) caused by concurrent writers during SDK streaming:

**Layer 1: Upsert (already in PR)**
- add_chat_messages_batch uses upsert() instead of create()
- Explicitly constructs update_data excluding Session and sequence
- Final safety net: duplicate sequences update instead of crash

**Layer 2: Query DB Before Each Save (NEW)**
- Query get_chat_session_message_count() before each save
- DB is source of truth, prevents using stale in-memory counter
- Applied to: long-running callback + 2 incremental saves
- Trade-off: Extra COUNT query (~1-2ms), but prevents race

**Layer 3: Shared Counter (NEW)**
- saved_msg_count_ref as mutable list[int] shared between:
  - Streaming loop (incremental saves)
  - Long-running callback (_build_long_running_callback)
- Both writers update it after successful save
- Keeps in-memory tracking accurate for performance

**Why all three:**
- Layer 2 alone: adds DB queries (performance cost)
- Layer 3 alone: doesn't handle external writers
- Layer 1 alone: may silently overwrite data
- Together: correctness + performance + safety net

Files:
- backend/copilot/db.py - Layer 1 (upsert with explicit update_data)
- backend/copilot/sdk/service.py - Layers 2 & 3

Fixes race where long-running tools (create_agent, edit_agent) would
append messages behind streaming loop's back, causing stale counter.

Addresses PR review comments and Discord analysis.
2026-02-20 18:02:00 +07:00
Zamil Majdy
eb4650fbb8 fix(copilot): explicitly construct update_data for better type safety
Instead of filtering from data dict, explicitly build update_data with
only the fields that should be updated. This is safer and makes it
obvious what fields are being updated in the upsert operation.

Addresses PR review comment about exhaustive field construction.
2026-02-20 17:53:52 +07:00
Zamil Majdy
8bdf83128e fix(copilot): address CodeRabbit review - add type safety and exclude sequence from update
- Add ChatMessageUpdateInput import for type-safe update payload
- Exclude both 'Session' and 'sequence' from update_data (sequence is part of composite key)
- Cast update_data to ChatMessageUpdateInput for type checking
- Update docstring to document upsert semantics and idempotency
2026-02-20 17:49:11 +07:00
Zamil Majdy
a1d5b99226 Merge branch 'dev' into otto/fix-chat-messages-batch-upsert 2026-02-20 17:48:02 +07:00
Otto
0450ea5313 fix(copilot): use upsert in add_chat_messages_batch to handle duplicate sequences
Concurrent writers (incremental streaming saves and long-running tool
callbacks) can race to persist messages with the same (sessionId, sequence)
pair, causing unique constraint violations.

Replace prisma create() with upsert() so duplicate sequences update the
existing row instead of failing. This is safe because later writes always
contain the most complete data (e.g. accumulated assistant text).
2026-02-20 09:59:56 +00:00
Zamil Majdy
9cdcd6793f fix(copilot): remove stream timeout, add error propagation to frontend (#12175)
## Summary

Fixes critical reliability issues where long-running copilot sessions
were forcibly terminated and failures showed no error messages to users.

## Issues Fixed

1. **Silent failures**: Tasks failed but frontend showed "stopped" with
zero explanation
2. **Premature timeout**: Sessions auto-expired after 5 minutes even
when actively running

## Changes

### Error propagation to frontend
- Add `error_message` parameter to `mark_task_completed()`
- When `status="failed"`, publish `StreamError` before `StreamFinish` so
frontend displays reason
- Update all failure callers with specific error messages:
  - Session not found: `"Session {id} not found"`
  - Tool setup failed: `"Failed to setup tool {name}: {error}"`  
  - Task cancelled: `"Task was cancelled"`

### Remove stream timeout
- Delete `stream_timeout` config (was 300s/5min)
- Remove auto-expiry logic in `get_active_task_for_session()`
- Sessions now run indefinitely — user controls stopping via UI

## Why

**Auto-expiry was broken:**
- Used `created_at` (task start) not last activity
- SDK sessions with multiple LLM calls + subagent Tasks easily run
20-30+ minutes
- A task publishing chunks every second still got killed at 5min mark
- Hard timeout is inappropriate for long-running AI agents

**Error propagation was missing:**
- `mark_task_completed(status="failed")` only sent `StreamFinish`
- No `StreamError` event = frontend had no message to show user
- Backend logs showed errors but user saw nothing

## Test Plan

- [x] Formatter, linter, type-check pass
- [ ] Start a copilot session with Task tool (spawns subagent)
- [ ] Verify session runs beyond 5 minutes without auto-expiry
- [ ] Cancel a running session → frontend shows "Task was cancelled"
error
- [ ] Trigger a tool setup failure → frontend shows error message
- [ ] Session continues running until user clicks stop or task completes

## Files Changed

- `backend/copilot/config.py` — removed `stream_timeout`
- `backend/copilot/stream_registry.py` — removed auto-expiry, added
error propagation
- `backend/copilot/service.py` — error messages for 2 failure paths
- `backend/copilot/executor/processor.py` — error message for
cancellation
2026-02-20 09:16:22 +00:00
Zamil Majdy
fc64f83331 fix(copilot): SDK streaming reliability, parallel tools, incremental saves, frontend reconnection (#12173)
## Summary

Fixes multiple reliability issues in the copilot's Claude Agent SDK
streaming pipeline — tool outputs getting stuck, parallel tool calls
flushing prematurely, messages lost on page refresh, and SSE
reconnection failures.

## Changes

### Backend: Streaming loop rewrite (`sdk/service.py`)
- **Non-cancelling heartbeat pattern**: Replace `asyncio.timeout()` with
`asyncio.wait()` for SDK message iteration. The old approach corrupted
the SDK's internal anyio memory stream when timeouts fired
mid-`__anext__()`, causing `StopAsyncIteration` on the next call and
silently dropping all in-flight tool results.
- **Hook synchronization**: Add `wait_for_stash()` before
`convert_message()` — the SDK fires PostToolUse hooks via `start_soon()`
(fire-and-forget), so the next message can arrive before the hook
stashes its output. The new asyncio.Event-based mechanism bridges this
gap without arbitrary sleeps.
- **Error handling**: Add `asyncio.CancelledError` handling at both
inner (streaming loop) and outer (session) levels, plus pending task
cleanup in `finally` block to prevent leaked coroutines. Catch
`Exception` from `done.pop().result()` for SDK error messages.
- **Safety-net flush**: After streaming loop ends, flush any remaining
unresolved tool calls so the frontend stops showing spinners even if the
stream drops unexpectedly.
- **StreamFinish fallback**: Emit `StreamFinishStep` + `StreamFinish`
when stream ends without `ResultMessage` (StopAsyncIteration) so the
frontend transitions to "ready" state.
- **Incremental session saves**: Save session to PostgreSQL after each
tool input/output event (not just at stream end), so page refresh and
other devices see recent messages.
- **Enhanced logging**: All log lines now include `session_id[:12]`
prefix and tool call resolution state (unresolved/current/resolved
counts).

### Backend: Response adapter (`sdk/response_adapter.py`)
- **Parallel tool call support**: Skip `_flush_unresolved_tool_calls()`
when an AssistantMessage contains only ToolUseBlocks (parallel
continuation) — the prior tools are still executing concurrently and
haven't finished yet.
- **Duplicate output prevention**: Skip already-resolved tool results in
both UserMessage (ToolResultBlock) and parent_tool_use_id handling to
prevent duplicate `StreamToolOutputAvailable` events.
- **`has_unresolved_tool_calls` property**: Used by the streaming loop
to decide whether to wait for PostToolUse hooks.
- **`session_id` parameter**: Passed through for structured logging.

### Backend: Hook synchronization (`sdk/tool_adapter.py`)
- **`_stash_event` ContextVar**: asyncio.Event signaled by
`stash_pending_tool_output()` whenever a PostToolUse hook stashes
output.
- **`wait_for_stash()`**: Awaits the event with configurable timeout —
replaces the racy "hope the hook finished" approach.

### Backend: Security hooks (`sdk/security_hooks.py`)
- Enhanced logging in `post_tool_use_hook` — log whether tool is
built-in, preview of stashed output, warning when `tool_response` is
None.

### Backend: Incremental save optimization (`model.py`)
- **`existing_message_count` parameter** on `upsert_chat_session`: Skip
the DB query to count existing messages when the caller already tracks
this (streaming loop).
- **`skip_existence_check` parameter** on `_save_session_to_db`: Skip
the `get_chat_session` existence query when we know the session row
already exists. Reduces from 4 DB round trips to 2 per incremental save.

### Backend: SDK version bump (`pyproject.toml`, `poetry.lock`)
- Bump `claude-agent-sdk` from `^0.1.0` to `^0.1.39`.

### Backend: New tests
- **`sdk_compat_test.py`** (new file): SDK compatibility tests — verify
the installed SDK exposes every class, attribute, and method the copilot
integration relies on. Catches SDK upgrade breakage immediately.
- **`response_adapter_test.py`**: 9 new tests covering
flush-at-ResultMessage, flush-at-next-AssistantMessage, stashed output
flush, wait_for_stash signaling/timeout/fast-path, parallel tool call
non-premature-flush, text-message flush of prior tools, and
already-resolved tool skip in UserMessage.

### Frontend: Session hydration (`convertChatSessionToUiMessages.ts`)
- **`isComplete` option**: When session has no active stream, mark
dangling tool calls (no output in DB) as `output-available` with empty
output — stops stale spinners after page refresh.

### Frontend: Chat session hook (`useChatSession.ts`)
- Reorder `hasActiveStream` memo before `hydratedMessages` so
`isComplete` flag is available.
- Pass `{ isComplete: !hasActiveStream }` to
`convertChatSessionMessagesToUiMessages`.

### Frontend: Copilot page hook (`useCopilotPage.ts`)
- **Cache invalidation on stream end**: Invalidate React Query session
cache when stream transitions active → idle, so next hydration fetches
fresh messages from backend (staleTime: Infinity otherwise keeps stale
data).
- **Resume ref reset**: Reset `hasResumedRef` on stream end to allow
re-resume if SSE drops but backend task is still running.
- **Remove old `resolveInProgressTools` effect**: Replaced by
backend-side safety-net flush + hydration-time `isComplete` marking.

## Test plan
- [ ] Existing copilot tests pass (`pytest backend/copilot/ -x -q`)
- [ ] SDK compat tests pass (`pytest
backend/copilot/sdk/sdk_compat_test.py -v`)
- [ ] Tool outputs (bash_exec, web_fetch, WebSearch) appear in the UI
instead of getting stuck
- [ ] Parallel tool calls (e.g. multiple WebSearch) complete and display
results without premature flush
- [ ] Page refresh during active stream reconnects and recovers messages
- [ ] Opening session from another device shows recent tool results
- [ ] SSE drop → automatic reconnection without losing messages
- [ ] Long-running tools (create_agent) still delegate to background
infrastructure
2026-02-20 08:25:08 +00:00
20 changed files with 1306 additions and 251 deletions

View File

@@ -27,7 +27,6 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
@@ -39,8 +38,10 @@ class ChatConfig(BaseSettings):
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
default=3600,
description="TTL in seconds for long-running operation deduplication lock "
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
"For longer operations, the stream_registry heartbeat keeps them alive.",
)
# Stream registry configuration for SSE reconnection

View File

@@ -132,58 +132,97 @@ async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[ChatMessage]:
) -> tuple[list[ChatMessage], int]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
Uses collision detection with retry: tries to create messages starting
at start_sequence. If a unique constraint violation occurs (e.g., the
streaming loop and long-running callback race), queries MAX(sequence)
and retries with the correct next sequence number. This avoids
unnecessary upserts and DB queries in the common case (no collision).
Returns:
Tuple of (messages, final_message_count) where final_message_count
is the total number of messages in the session after insertion.
This allows callers to update their counters even when collision
detection adjusts start_sequence.
"""
if not messages:
return []
# No messages to add - return current count
return [], start_sequence
created_messages = []
max_retries = 3
for attempt in range(max_retries):
try:
created_messages = []
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
# Return messages and final message count (for shared counter sync)
final_count = start_sequence + len(messages)
return [ChatMessage.from_db(m) for m in created_messages], final_count
except Exception as e:
# Check if it's a unique constraint violation
error_msg = str(e).lower()
is_unique_constraint = (
"unique constraint" in error_msg or "duplicate key" in error_msg
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
if is_unique_constraint and attempt < max_retries - 1:
# Collision detected - query MAX(sequence)+1 and retry with correct offset
logger.info(
f"Collision detected for session {session_id} at sequence "
f"{start_sequence}, querying DB for latest sequence"
)
start_sequence = await get_next_sequence(session_id)
logger.info(
f"Retrying batch insert with start_sequence={start_sequence}"
)
continue
else:
# Not a collision or max retries exceeded - propagate error
raise
return [ChatMessage.from_db(m) for m in created_messages]
# Should never reach here due to raise in exception handler
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
async def get_user_chat_sessions(
@@ -237,10 +276,23 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return False
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def get_next_sequence(session_id: str) -> int:
"""Get the next sequence number for a new message in this session.
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
More robust than COUNT(*) because it's immune to deleted messages.
"""
result = await db.prisma.query_raw(
"""
SELECT COALESCE(MAX(sequence) + 1, 0) as next_seq
FROM "ChatMessage"
WHERE "sessionId" = $1
""",
session_id,
)
if not result or len(result) == 0:
return 0
return int(result[0]["next_seq"])
async def update_tool_message_content(

View File

@@ -266,7 +266,11 @@ class CoPilotProcessor:
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_task_completed(entry.task_id, status="failed")
await stream_registry.mark_task_completed(
entry.task_id,
status="failed",
error_message="Task was cancelled",
)
raise
except Exception as e:

View File

@@ -432,13 +432,27 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return session
async def upsert_chat_session(session: ChatSession) -> ChatSession:
async def upsert_chat_session(
session: ChatSession,
*,
existing_message_count: int | None = None,
) -> tuple[ChatSession, int]:
"""Update a chat session in both cache and database.
Uses session-level locking to prevent race conditions when concurrent
operations (e.g., background title update and main stream handler)
attempt to upsert the same session simultaneously.
Args:
existing_message_count: If provided, skip the DB query to count
existing messages. The caller is responsible for tracking this
accurately. Useful for incremental saves in a streaming loop
where the caller already knows how many messages are persisted.
Returns:
Tuple of (session, final_message_count) where final_message_count is
the actual persisted message count after collision detection adjustments.
Raises:
DatabaseError: If the database write fails. The cache is still updated
as a best-effort optimization, but the error is propagated to ensure
@@ -450,15 +464,21 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
async with lock:
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db().get_chat_session_message_count(
session.session_id
)
if existing_message_count is None:
existing_message_count = await chat_db().get_next_sequence(
session.session_id
)
db_error: Exception | None = None
final_count = existing_message_count
# Save to database (primary storage)
try:
await _save_session_to_db(session, existing_message_count)
final_count = await _save_session_to_db(
session,
existing_message_count,
skip_existence_check=existing_message_count > 0,
)
except Exception as e:
logger.error(
f"Failed to save session {session.session_id} to database: {e}"
@@ -485,25 +505,38 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
f"Failed to persist chat session {session.session_id} to database"
) from db_error
return session
return session, final_count
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
session: ChatSession,
existing_message_count: int,
*,
skip_existence_check: bool = False,
) -> int:
"""Save or update a chat session in the database.
Args:
skip_existence_check: When True, skip the ``get_chat_session`` query
and assume the session row already exists. Saves one DB round trip
for incremental saves during streaming.
Returns:
Final message count after save (accounting for collision detection).
"""
db = chat_db()
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not skip_existence_check:
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
@@ -521,6 +554,7 @@ async def _save_session_to_db(
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
final_count = existing_message_count
if new_messages:
messages_data = []
for msg in new_messages:
@@ -540,12 +574,14 @@ async def _save_session_to_db(
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await db.add_chat_messages_batch(
_, final_count = await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
return final_count
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
@@ -562,9 +598,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_chat_session_message_count(
session_id
)
existing_message_count = await chat_db().get_next_sequence(session_id)
try:
await _save_session_to_db(session, existing_message_count)

View File

@@ -60,7 +60,7 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s, _ = await upsert_chat_session(s)
s2 = await get_chat_session(
session_id=s.session_id,
@@ -77,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s, _ = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
@@ -94,7 +94,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
s, _ = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"

View File

@@ -47,8 +47,9 @@ class SDKResponseAdapter:
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
def __init__(self, message_id: str | None = None, session_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.session_id = session_id
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
@@ -61,6 +62,11 @@ class SDKResponseAdapter:
"""Set the task ID for reconnection support."""
self.task_id = task_id
@property
def has_unresolved_tool_calls(self) -> bool:
"""True when there are tool calls that haven't received output yet."""
return bool(self.current_tool_calls.keys() - self.resolved_tool_calls)
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
@@ -77,7 +83,12 @@ class SDKResponseAdapter:
elif isinstance(sdk_message, AssistantMessage):
# Flush any SDK built-in tool calls that didn't get a UserMessage
# result (e.g. WebSearch, Read handled internally by the CLI).
self._flush_unresolved_tool_calls(responses)
# BUT skip flush when this AssistantMessage is a parallel tool
# continuation (contains only ToolUseBlocks) — the prior tools
# are still executing concurrently and haven't finished yet.
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
if not is_tool_only:
self._flush_unresolved_tool_calls(responses)
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
@@ -118,8 +129,24 @@ class SDKResponseAdapter:
blocks = content if isinstance(content, list) else []
resolved_in_blocks: set[str] = set()
sid = (self.session_id or "?")[:12]
parent_id_preview = getattr(sdk_message, "parent_tool_use_id", None)
logger.info(
"[SDK] [%s] UserMessage: %d blocks, content_type=%s, "
"parent_tool_use_id=%s",
sid,
len(blocks),
type(content).__name__,
parent_id_preview[:12] if parent_id_preview else "None",
)
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
# Skip if already resolved (e.g. by flush) — the real
# result supersedes the empty flush, but re-emitting
# would confuse the frontend's state machine.
if block.tool_use_id in self.resolved_tool_calls:
continue
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
@@ -144,7 +171,11 @@ class SDKResponseAdapter:
# Handle SDK built-in tool results carried via parent_tool_use_id
# instead of (or in addition to) ToolResultBlock content.
parent_id = sdk_message.parent_tool_use_id
if parent_id and parent_id not in resolved_in_blocks:
if (
parent_id
and parent_id not in resolved_in_blocks
and parent_id not in self.resolved_tool_calls
):
tool_info = self.current_tool_calls.get(parent_id, {})
tool_name = tool_info.get("name", "unknown")
@@ -228,11 +259,28 @@ class SDKResponseAdapter:
output, which we pop and emit here before the next ``AssistantMessage``
starts.
"""
unresolved = [
(tid, info.get("name", "unknown"))
for tid, info in self.current_tool_calls.items()
if tid not in self.resolved_tool_calls
]
sid = (self.session_id or "?")[:12]
if not unresolved:
logger.info(
"[SDK] [%s] Flush called but all %d tool(s) already resolved",
sid,
len(self.current_tool_calls),
)
return
logger.info(
"[SDK] [%s] Flushing %d unresolved tool call(s): %s",
sid,
len(unresolved),
", ".join(f"{name}({tid[:12]})" for tid, name in unresolved),
)
flushed = False
for tool_id, tool_info in self.current_tool_calls.items():
if tool_id in self.resolved_tool_calls:
continue
tool_name = tool_info.get("name", "unknown")
for tool_id, tool_name in unresolved:
output = pop_pending_tool_output(tool_name)
if output is not None:
responses.append(
@@ -245,9 +293,12 @@ class SDKResponseAdapter:
)
self.resolved_tool_calls.add(tool_id)
flushed = True
logger.debug(
f"Flushed pending output for built-in tool {tool_name} "
f"(call {tool_id})"
logger.info(
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
sid,
tool_name,
tool_id[:12],
len(output),
)
else:
# No output available — emit an empty output so the frontend
@@ -263,9 +314,14 @@ class SDKResponseAdapter:
)
self.resolved_tool_calls.add(tool_id)
flushed = True
logger.debug(
f"Flushed empty output for unresolved tool {tool_name} "
f"(call {tool_id})"
logger.warning(
"[SDK] [%s] Flushed EMPTY output for unresolved tool %s "
"(call %s) — stash was empty (likely SDK hook race "
"condition: PostToolUse hook hadn't completed before "
"flush was triggered)",
sid,
tool_name,
tool_id[:12],
)
if flushed and self.step_open:

View File

@@ -1,5 +1,8 @@
"""Unit tests for the SDK response adapter."""
import asyncio
import pytest
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
@@ -27,6 +30,10 @@ from backend.copilot.response_model import (
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
from .tool_adapter import _stash_event
from .tool_adapter import stash_pending_tool_output as _stash
from .tool_adapter import wait_for_stash
def _adapter() -> SDKResponseAdapter:
@@ -364,3 +371,310 @@ def test_full_conversation_flow():
"StreamFinishStep", # step 2 closed
"StreamFinish",
]
# -- Flush unresolved tool calls --------------------------------------------
def test_flush_unresolved_at_result_message():
"""Built-in tools (WebSearch) without UserMessage results get flushed at ResultMessage."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Tool use (built-in tool — no MCP prefix)
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# 3. No UserMessage for this tool — go straight to ResultMessage
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep",
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # flushed with empty output
"StreamFinishStep", # step closed by flush
"StreamFinish",
]
# The flushed output should be empty (no stash available)
output_event = [
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
][0]
assert output_event.toolCallId == "ws-1"
assert output_event.toolName == "WebSearch"
assert output_event.output == ""
def test_flush_unresolved_at_next_assistant_message():
"""Built-in tools get flushed when the next AssistantMessage arrives."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Tool use (built-in — no UserMessage will come)
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# 3. Next AssistantMessage triggers flush before processing its blocks
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="Here are the results")], model="test"
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep",
"StreamToolInputStart",
"StreamToolInputAvailable",
# Flush at next AssistantMessage:
"StreamToolOutputAvailable",
"StreamFinishStep", # step closed by flush
# New step for continuation text:
"StreamStartStep",
"StreamTextStart",
"StreamTextDelta",
]
def test_flush_with_stashed_output():
"""Stashed output from PostToolUse hook is used when flushing."""
adapter = _adapter()
# Simulate PostToolUse hook stashing output
_pto.set({})
_stash("WebSearch", "Search result: 5 items found")
all_responses: list[StreamBaseResponse] = []
# Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# ResultMessage triggers flush
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
)
)
output_events = [
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
]
assert len(output_events) == 1
assert output_events[0].output == "Search result: 5 items found"
# Cleanup
_pto.set({}) # type: ignore[arg-type]
# -- wait_for_stash synchronisation tests --
@pytest.mark.asyncio
async def test_wait_for_stash_signaled():
"""wait_for_stash returns True when stash_pending_tool_output signals."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
# Simulate a PostToolUse hook that stashes output after a short delay
async def delayed_stash():
await asyncio.sleep(0.01)
_stash("WebSearch", "result data")
asyncio.create_task(delayed_stash())
result = await wait_for_stash(timeout=1.0)
assert result is True
assert _pto.get({}).get("WebSearch") == ["result data"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@pytest.mark.asyncio
async def test_wait_for_stash_timeout():
"""wait_for_stash returns False on timeout when no stash occurs."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
result = await wait_for_stash(timeout=0.05)
assert result is False
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@pytest.mark.asyncio
async def test_wait_for_stash_already_stashed():
"""wait_for_stash picks up a stash that happened just before the wait."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
# Stash before waiting — simulates hook completing before message arrives
_stash("Read", "file contents")
# Event is now set; wait_for_stash detects the fast path and returns
# immediately without timing out.
result = await wait_for_stash(timeout=0.05)
assert result is True
# But the stash itself is populated
assert _pto.get({}).get("Read") == ["file contents"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
# -- Parallel tool call tests --
def test_parallel_tool_calls_not_flushed_prematurely():
"""Parallel tool calls should NOT be flushed when the next AssistantMessage
only contains ToolUseBlocks (parallel continuation)."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# First AssistantMessage: tool call #1
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
r1 = adapter.convert_message(msg1)
assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
assert adapter.has_unresolved_tool_calls
# Second AssistantMessage: tool call #2 (parallel continuation)
msg2 = AssistantMessage(
content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
model="test",
)
r2 = adapter.convert_message(msg2)
# No flush should have happened — t1 should NOT have StreamToolOutputAvailable
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 0, (
f"Tool-only AssistantMessage should not flush prior tools, "
f"but got {len(output_events)} output events"
)
# Both t1 and t2 should still be unresolved
assert "t1" not in adapter.resolved_tool_calls
assert "t2" not in adapter.resolved_tool_calls
def test_text_assistant_message_flushes_prior_tools():
"""An AssistantMessage with text (new turn) should flush unresolved tools."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
adapter.convert_message(msg1)
assert adapter.has_unresolved_tool_calls
# Text AssistantMessage (new turn after tools completed)
msg2 = AssistantMessage(
content=[TextBlock(text="Here are the results")],
model="test",
)
r2 = adapter.convert_message(msg2)
# Flush SHOULD have happened — t1 gets empty output
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 1
assert output_events[0].toolCallId == "t1"
assert "t1" in adapter.resolved_tool_calls
def test_already_resolved_tool_skipped_in_user_message():
"""A tool result in UserMessage should be skipped if already resolved by flush."""
adapter = SDKResponseAdapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call + flush via text message
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
model="test",
)
)
adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="Done")],
model="test",
)
)
assert "t1" in adapter.resolved_tool_calls
# Now UserMessage arrives with the real result — should be skipped
user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
r = adapter.convert_message(user_msg)
output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
assert (
len(output_events) == 0
), "Already-resolved tool should not emit duplicate output"

View File

@@ -0,0 +1,194 @@
"""SDK compatibility tests — verify the claude-agent-sdk public API surface we depend on.
Instead of pinning to a narrow version range, these tests verify that the
installed SDK exposes every class, function, attribute, and method the copilot
integration relies on. If an SDK upgrade removes or renames something these
tests will catch it immediately.
"""
import inspect
import pytest
# ---------------------------------------------------------------------------
# Public types & factories
# ---------------------------------------------------------------------------
def test_sdk_exports_client_and_options():
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
assert inspect.isclass(ClaudeSDKClient)
assert inspect.isclass(ClaudeAgentOptions)
def test_sdk_exports_message_types():
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
UserMessage,
)
for cls in (AssistantMessage, ResultMessage, SystemMessage, UserMessage):
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
# Message is a Union type alias, just verify it's importable
assert Message is not None
def test_sdk_exports_content_block_types():
from claude_agent_sdk import TextBlock, ToolResultBlock, ToolUseBlock
for cls in (TextBlock, ToolResultBlock, ToolUseBlock):
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
def test_sdk_exports_mcp_helpers():
from claude_agent_sdk import create_sdk_mcp_server, tool
assert callable(create_sdk_mcp_server)
assert callable(tool)
# ---------------------------------------------------------------------------
# ClaudeSDKClient interface
# ---------------------------------------------------------------------------
def test_client_has_required_methods():
from claude_agent_sdk import ClaudeSDKClient
required = ["connect", "disconnect", "query", "receive_messages"]
for name in required:
attr = getattr(ClaudeSDKClient, name, None)
assert attr is not None, f"ClaudeSDKClient.{name} missing"
assert callable(attr), f"ClaudeSDKClient.{name} is not callable"
def test_client_supports_async_context_manager():
from claude_agent_sdk import ClaudeSDKClient
assert hasattr(ClaudeSDKClient, "__aenter__")
assert hasattr(ClaudeSDKClient, "__aexit__")
# ---------------------------------------------------------------------------
# ClaudeAgentOptions fields
# ---------------------------------------------------------------------------
def test_agent_options_accepts_required_fields():
"""Verify ClaudeAgentOptions accepts all kwargs our code passes."""
from claude_agent_sdk import ClaudeAgentOptions
opts = ClaudeAgentOptions(
system_prompt="test",
cwd="/tmp",
)
assert opts.system_prompt == "test"
assert opts.cwd == "/tmp"
def test_agent_options_accepts_all_our_fields():
"""Comprehensive check of every field we use in service.py."""
from claude_agent_sdk import ClaudeAgentOptions
fields_we_use = [
"system_prompt",
"mcp_servers",
"allowed_tools",
"disallowed_tools",
"hooks",
"cwd",
"model",
"env",
"resume",
"max_buffer_size",
]
sig = inspect.signature(ClaudeAgentOptions)
for field in fields_we_use:
assert field in sig.parameters, (
f"ClaudeAgentOptions no longer accepts '{field}'"
f"available params: {list(sig.parameters.keys())}"
)
# ---------------------------------------------------------------------------
# Message attributes
# ---------------------------------------------------------------------------
def test_assistant_message_has_content_and_model():
from claude_agent_sdk import AssistantMessage, TextBlock
msg = AssistantMessage(content=[TextBlock(text="hi")], model="test")
assert hasattr(msg, "content")
assert hasattr(msg, "model")
def test_result_message_has_required_attrs():
from claude_agent_sdk import ResultMessage
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
assert msg.subtype == "success"
assert hasattr(msg, "result")
def test_system_message_has_subtype_and_data():
from claude_agent_sdk import SystemMessage
msg = SystemMessage(subtype="init", data={})
assert msg.subtype == "init"
assert msg.data == {}
def test_user_message_has_parent_tool_use_id():
from claude_agent_sdk import UserMessage
msg = UserMessage(content="test")
assert hasattr(msg, "parent_tool_use_id")
assert hasattr(msg, "tool_use_result")
def test_tool_use_block_has_id_name_input():
from claude_agent_sdk import ToolUseBlock
block = ToolUseBlock(id="t1", name="test", input={"key": "val"})
assert block.id == "t1"
assert block.name == "test"
assert block.input == {"key": "val"}
def test_tool_result_block_has_required_attrs():
from claude_agent_sdk import ToolResultBlock
block = ToolResultBlock(tool_use_id="t1", content="result")
assert block.tool_use_id == "t1"
assert block.content == "result"
assert hasattr(block, "is_error")
# ---------------------------------------------------------------------------
# Hook types
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hook_event",
["PreToolUse", "PostToolUse", "Stop"],
)
def test_sdk_exports_hook_event_type(hook_event: str):
"""Verify HookEvent literal includes the events our security_hooks use."""
from claude_agent_sdk.types import HookEvent
# HookEvent is a Literal type — check that our events are valid values.
# We can't easily inspect Literal at runtime, so just verify the type exists.
assert HookEvent is not None

View File

@@ -246,15 +246,33 @@ def create_security_hooks(
"""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
logger.info(
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
tool_name,
is_builtin,
(tool_use_id or "")[:12],
)
# Stash output for SDK built-in tools so the response adapter can
# emit StreamToolOutputAvailable even when the CLI doesn't surface
# a separate UserMessage with ToolResultBlock content.
if not tool_name.startswith(MCP_TOOL_PREFIX):
if is_builtin:
tool_response = input_data.get("tool_response")
if tool_response is not None:
resp_preview = str(tool_response)[:100]
logger.info(
"[SDK] Stashing builtin output for %s (%d chars): %s...",
tool_name,
len(str(tool_response)),
resp_preview,
)
stash_pending_tool_output(tool_name, tool_response)
else:
logger.warning(
"[SDK] PostToolUse for builtin %s but tool_response is None",
tool_name,
)
return cast(SyncHookJSONOutput, {})

View File

@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from backend.data.redis_client import get_redis_async
from backend.util.exceptions import NotFoundError
from .. import stream_registry
@@ -24,6 +25,7 @@ from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
@@ -46,6 +48,7 @@ from .tool_adapter import (
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
wait_for_stash,
)
from .transcript import (
cleanup_cli_project_dir,
@@ -130,8 +133,65 @@ is delivered to the user via a background stream.
All tasks must run in the foreground.
"""
# Session streaming lock configuration
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
STREAM_LOCK_TTL = 3600 # 1 hour - matches stream_ttl
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
async def _acquire_stream_lock(session_id: str, stream_id: str) -> bool:
"""Acquire an exclusive lock for streaming to this session.
Prevents multiple concurrent streams to the same session which can cause:
- Message duplication
- Race conditions in message saves
- Confusing UX with multiple AI responses
Returns:
True if lock was acquired, False if another stream is active.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# SET NX EX - atomic "set if not exists" with expiry
result = await redis.set(lock_key, stream_id, ex=STREAM_LOCK_TTL, nx=True)
return result is not None
async def _release_stream_lock(session_id: str, stream_id: str) -> None:
"""Release the stream lock if we still own it.
Only releases the lock if the stored stream_id matches ours (prevents
releasing another stream's lock if we somehow timed out).
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# Lua script for atomic compare-and-delete (only delete if value matches)
script = """
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end
"""
await redis.eval(script, 1, lock_key, stream_id) # type: ignore[misc]
async def check_active_stream(session_id: str) -> str | None:
"""Check if a stream is currently active for this session.
Returns:
The active stream_id if one exists, None otherwise.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
active_stream = await redis.get(lock_key)
return active_stream.decode() if isinstance(active_stream, bytes) else active_stream
def _build_long_running_callback(
user_id: str | None,
saved_msg_count_ref: list[int] | None = None,
) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
@@ -140,6 +200,12 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
Args:
user_id: User ID for the session
saved_msg_count_ref: Mutable reference [count] shared with streaming loop
for coordinating message saves. When provided, the callback will update
it after appending messages to prevent counter drift.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
@@ -205,7 +271,11 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
# Collision detection happens in add_chat_messages_batch (db.py)
_, final_count = await upsert_chat_session(session)
# Update shared counter so streaming loop stays in sync
if saved_msg_count_ref is not None:
saved_msg_count_ref[0] = final_count
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
@@ -344,15 +414,15 @@ async def _compress_conversation_history(
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
messages = session.messages[:-1]
if len(messages) < 2:
return messages
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
for msg in messages:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
@@ -400,7 +470,7 @@ async def _compress_conversation_history(
for m in result.messages
]
return prior
return messages
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
@@ -442,8 +512,8 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
def _is_tool_error_or_denial(content: str | None) -> bool:
"""Check if a tool message content indicates an error or denial.
We include these in conversation context so the agent doesn't
hallucinate success for operations that actually failed.
Currently unused — ``_format_conversation_context`` includes all tool
results. Kept as a utility for future selective filtering.
"""
if not content:
return False
@@ -458,7 +528,7 @@ def _is_tool_error_or_denial(content: str | None) -> bool:
"maximum", # subtask-limit denial
"denied",
"blocked",
"failed", # internal tool execution failures
"failed to", # internal tool execution failures
'"iserror": true', # MCP protocol error flag
)
)
@@ -540,7 +610,7 @@ async def stream_chat_completion_sdk(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
@@ -562,6 +632,23 @@ async def stream_chat_completion_sdk(
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
stream_id = task_id # Use task_id as unique stream identifier
# Acquire stream lock to prevent concurrent streams to the same session
lock_acquired = await _acquire_stream_lock(session_id, stream_id)
if not lock_acquired:
# Another stream is active - check if it's still alive
active_stream = await check_active_stream(session_id)
logger.warning(
f"[SDK] Session {session_id} already has an active stream: {active_stream}"
)
yield StreamError(
errorText="Another stream is already active for this session. "
"Please wait for it to complete or refresh the page.",
code="stream_already_active",
)
yield StreamFinish()
return
yield StreamStart(messageId=message_id, taskId=task_id)
@@ -579,10 +666,16 @@ async def stream_chat_completion_sdk(
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
# Initialize saved message counter as mutable list so long-running
# callback and streaming loop can coordinate
saved_msg_count_ref: list[int] = [len(session.messages)]
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
long_running_callback=_build_long_running_callback(
user_id, saved_msg_count_ref
),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
@@ -674,7 +767,7 @@ async def stream_chat_completion_sdk(
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id)
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
@@ -699,10 +792,13 @@ async def stream_chat_completion_sdk(
transcript_msg_count,
session_id,
)
logger.info(
f"[SDK] Sending query ({len(session.messages)} msgs, "
f"resume={use_resume})"
"[SDK] [%s] Sending query — resume=%s, "
"total_msgs=%d, query_len=%d",
session_id[:12],
use_resume,
len(session.messages),
len(query_message),
)
await client.query(query_message, session_id=session_id)
@@ -710,98 +806,293 @@ async def stream_chat_completion_sdk(
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
# Track persisted message count. Uses shared ref so long-running
# callback can update it for coordination
# Use an explicit async iterator with timeout to send
# heartbeats when the CLI is idle (e.g. executing tools).
# This prevents proxies/LBs from closing the SSE connection.
# asyncio.timeout() is preferred over asyncio.wait_for()
# because wait_for wraps in a separate Task whose cancellation
# can leave the async generator in a broken state.
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
# anyio memory stream, causing StopAsyncIteration on the next
# call and silently dropping all in-flight tool results.
# Instead, wrap __anext__() in a Task and use asyncio.wait()
# with a timeout. On timeout we emit a heartbeat but keep the
# Task alive so it can deliver the next message.
msg_iter = client.receive_messages().__aiter__()
while not stream_completed:
try:
async with asyncio.timeout(_HEARTBEAT_INTERVAL):
sdk_msg = await msg_iter.__anext__()
except TimeoutError:
yield StreamHeartbeat()
continue
except StopAsyncIteration:
break
pending_task: asyncio.Task[Any] | None = None
try:
while not stream_completed:
if pending_task is None:
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
async def _next_msg() -> Any:
return await msg_iter.__anext__()
pending_task = asyncio.create_task(_next_msg())
done, _ = await asyncio.wait(
{pending_task}, timeout=_HEARTBEAT_INTERVAL
)
if not done:
# Timeout — emit heartbeat but keep the task alive
yield StreamHeartbeat()
continue
# Log tool events for debugging visibility issues
# Task completed — get result
pending_task = None
try:
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally "
"(StopAsyncIteration)",
session_id[:12],
)
break
except Exception as stream_err:
# SDK sends {"type": "error"} which raises
# Exception in receive_messages() — capture it
# so the session can still be saved and the
# frontend gets a clean finish.
logger.error(
"[SDK] [%s] Stream error from SDK: %s",
session_id[:12],
stream_err,
exc_info=True,
)
yield StreamError(
errorText=f"SDK stream error: {stream_err}",
code="sdk_stream_error",
)
break
logger.info(
"[SDK] [%s] Received: %s %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
# wait_for_stash() awaits an asyncio.Event signaled by
# stash_pending_tool_output(), completing as soon as
# the hook finishes (typically <1ms). The sleep(0)
# after lets any remaining concurrent hooks complete.
#
# Skip for parallel tool continuations: when the SDK
# sends parallel tool calls as separate
# AssistantMessages (each containing only
# ToolUseBlocks), we must NOT wait/flush — the prior
# tools are still executing concurrently.
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
ToolUseBlock,
)
is_parallel_continuation = isinstance(
sdk_msg, AssistantMessage
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
if (
adapter.has_unresolved_tool_calls
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
and not is_parallel_continuation
):
if await wait_for_stash(timeout=0.5):
await asyncio.sleep(0)
else:
logger.warning(
"[SDK] [%s] Timed out waiting for "
"PostToolUse hook stash "
"(%d unresolved tool calls)",
session_id[:12],
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
# Log tool events for debugging
if isinstance(
response,
(
StreamToolInputAvailable,
StreamToolOutputAvailable,
),
):
extra = ""
if isinstance(response, StreamToolOutputAvailable):
out_len = len(str(response.output))
extra = f", output_len={out_len}"
logger.info(
"[SDK] [%s] Tool event: %s, tool=%s%s",
session_id[:12],
type(response).__name__,
getattr(response, "toolName", "N/A"),
extra,
)
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(
response.input or {}
),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# Save before tool execution starts so the
# pending tool call is visible on refresh /
# other devices. Collision detection happens
# in add_chat_messages_batch (db.py).
try:
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
# Save after tool completes so the result is
# visible on refresh / other devices.
# Collision detection happens in add_chat_messages_batch (db.py).
try:
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamFinish):
stream_completed = True
except asyncio.CancelledError:
# Task/generator was cancelled (e.g. client disconnect,
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled "
"(asyncio.CancelledError)",
session_id[:12],
)
raise
finally:
# Cancel the pending __anext__ task to avoid a leaked
# coroutine. This is safe even if the task already
# completed.
if pending_task is not None and not pending_task.done():
pending_task.cancel()
try:
await pending_task
except (asyncio.CancelledError, StopAsyncIteration):
pass
# Safety net: if tools are still unresolved after the
# streaming loop (e.g. StopAsyncIteration before ResultMessage,
# or SDK not sending UserMessages for built-in tools), flush
# them now so the frontend stops showing spinners.
if adapter.has_unresolved_tool_calls:
logger.warning(
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
"flushing as safety net",
session_id[:12],
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
adapter._flush_unresolved_tool_calls(safety_responses)
for response in safety_responses:
if isinstance(
response,
(StreamToolInputAvailable, StreamToolOutputAvailable),
):
logger.info(
"[SDK] Tool event: %s, tool=%s",
"[SDK] [%s] Safety flush: %s, tool=%s",
session_id[:12],
type(response).__name__,
getattr(response, "toolName", "N/A"),
)
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
# If the stream ended without a ResultMessage (no
# StreamFinish), the SDK CLI exited unexpectedly. Close
# the open step and emit StreamFinish so the frontend
# transitions to the "ready" state.
if not stream_completed:
logger.warning(
"[SDK] [%s] Stream ended without ResultMessage "
"(StopAsyncIteration) — emitting StreamFinish",
session_id[:12],
)
if adapter.step_open:
yield StreamFinishStep()
adapter.step_open = False
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
for r in closing_responses:
yield r
yield StreamFinish()
stream_completed = True
if (
assistant_response.content or assistant_response.tool_calls
@@ -856,13 +1147,22 @@ async def stream_chat_completion_sdk(
"to use the OpenAI-compatible fallback."
)
await asyncio.shield(upsert_chat_session(session))
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
_, final_count = await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session saved with %d messages (DB count: %d)",
session_id[:12],
len(session.messages),
final_count,
)
if not stream_completed:
yield StreamFinish()
except asyncio.CancelledError:
# Client disconnect / server shutdown — log but re-raise so
# the framework can clean up. The finally block still runs
# for transcript upload.
logger.warning("[SDK] [%s] Session cancelled (CancelledError)", session_id[:12])
raise
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
@@ -910,6 +1210,9 @@ async def stream_chat_completion_sdk(
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
# Release stream lock to allow new streams for this session
await _release_stream_lock(session_id, stream_id)
async def _try_upload_transcript(
user_id: str,

View File

@@ -9,6 +9,7 @@ via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import asyncio
import itertools
import json
import logging
@@ -44,6 +45,14 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete
# instead of sleeping an arbitrary duration. The SDK fires hooks via
# start_soon (fire-and-forget) so the next message can arrive before
# the hook stashes its output — this event bridges that gap.
_stash_event: ContextVar[asyncio.Event | None] = ContextVar(
"_stash_event", default=None
)
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
@@ -76,6 +85,7 @@ def set_execution_context(
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
_long_running_callback.set(long_running_callback)
@@ -134,6 +144,43 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
except (TypeError, ValueError):
text = str(output)
pending.setdefault(tool_name, []).append(text)
# Signal any waiters that new output is available.
event = _stash_event.get(None)
if event is not None:
event.set()
async def wait_for_stash(timeout: float = 0.5) -> bool:
"""Wait for a PostToolUse hook to stash tool output.
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
the next message (AssistantMessage/ResultMessage) can arrive before the
hook completes and stashes its output. This function bridges that gap
by waiting on the ``_stash_event``, which is signaled by
:func:`stash_pending_tool_output`.
After the event fires, callers should ``await asyncio.sleep(0)`` to
give any remaining concurrent hooks a chance to complete.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
The timeout is a safety net — normally the stash happens within
microseconds of yielding to the event loop.
"""
event = _stash_event.get(None)
if event is None:
return False
# Fast path: hook already completed before we got here.
if event.is_set():
event.clear()
return True
# Slow path: wait for the hook to signal.
try:
async with asyncio.timeout(timeout):
await event.wait()
event.clear()
return True
except TimeoutError:
return False
async def _execute_tool_sync(

View File

@@ -352,7 +352,8 @@ async def assign_user_to_session(
if not session:
raise NotFoundError(f"Session {session_id} not found")
session.user_id = user_id
return await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
return session
async def stream_chat_completion(
@@ -463,7 +464,7 @@ async def stream_chat_completion(
)
upsert_start = time.monotonic()
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
@@ -689,7 +690,7 @@ async def stream_chat_completion(
f"tool_responses={len(tool_response_messages)}"
)
if messages_to_save_early or has_appended_streaming_message:
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
has_saved_assistant_message = True
has_yielded_end = True
@@ -728,7 +729,7 @@ async def stream_chat_completion(
if tool_response_messages:
session.messages.extend(tool_response_messages)
try:
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to save interrupted session {session.session_id}: {e}"
@@ -769,7 +770,7 @@ async def stream_chat_completion(
if messages_to_save:
session.messages.extend(messages_to_save)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
if not has_yielded_error:
error_message = str(e)
@@ -853,7 +854,7 @@ async def stream_chat_completion(
not has_long_running_tool_call
and (messages_to_save or has_appended_streaming_message)
):
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
else:
logger.info(
"Assistant message already saved when StreamFinish was received, "
@@ -1525,7 +1526,7 @@ async def _yield_tool_call(
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
await _with_optional_lock(session_lock, _save_pending)
logger.info(
@@ -1563,7 +1564,11 @@ async def _yield_tool_call(
await _mark_operation_completed(tool_call_id)
# Mark stream registry task as failed if it was created
try:
await stream_registry.mark_task_completed(task_id, status="failed")
await stream_registry.mark_task_completed(
task_id,
status="failed",
error_message=f"Failed to setup tool {tool_name}: {e}",
)
except Exception as mark_err:
logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
logger.error(
@@ -1731,7 +1736,11 @@ async def _execute_long_running_tool_with_streaming(
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for background tool")
await stream_registry.mark_task_completed(task_id, status="failed")
await stream_registry.mark_task_completed(
task_id,
status="failed",
error_message=f"Session {session_id} not found",
)
return
# Pass operation_id and task_id to the tool for async processing
@@ -2011,7 +2020,7 @@ async def _generate_llm_continuation(
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
_ = await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)
@@ -2217,7 +2226,7 @@ async def _generate_llm_continuation_with_streaming(
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
_ = await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)

View File

@@ -58,7 +58,7 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
has_errors = False
has_ended = False
@@ -104,7 +104,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "ZEPHYR42"

View File

@@ -644,6 +644,8 @@ async def _stream_listener(
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
*,
error_message: str | None = None,
) -> bool:
"""Mark a task as completed and publish finish event.
@@ -654,6 +656,10 @@ async def mark_task_completed(
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
error_message: If provided and status="failed", publish a StreamError
before StreamFinish so connected clients see why the task ended.
If not provided, no StreamError is published (caller should publish
manually if needed to avoid duplicates).
Returns:
True if task was newly marked completed, False if already completed/failed
@@ -669,6 +675,17 @@ async def mark_task_completed(
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
# Publish error event before finish so connected clients know WHY the
# task ended. Only publish if caller provided an explicit error message
# to avoid duplicates with code paths that manually publish StreamError.
# This is best-effort — if it fails, the StreamFinish still ensures
# listeners clean up.
if status == "failed" and error_message:
try:
await publish_chunk(task_id, StreamError(errorText=error_message))
except Exception as e:
logger.warning(f"Failed to publish error event for task {task_id}: {e}")
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
@@ -821,27 +838,6 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError) as exc:
logger.warning(
f"[TASK_LOOKUP] Failed to parse created_at "
f"for task {task_id[:8]}...: {exc}"
)
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)

View File

@@ -303,7 +303,7 @@ class DatabaseManager(AppService):
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
get_user_session_count = _(chat_db.get_user_session_count)
delete_chat_session = _(chat_db.delete_chat_session)
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
get_next_sequence = _(chat_db.get_next_sequence)
update_tool_message_content = _(chat_db.update_tool_message_content)
@@ -473,5 +473,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_chat_sessions = d.get_user_chat_sessions
get_user_session_count = d.get_user_session_count
delete_chat_session = d.delete_chat_session
get_chat_session_message_count = d.get_chat_session_message_count
get_next_sequence = d.get_next_sequence
update_tool_message_content = d.update_tool_message_content

View File

@@ -899,17 +899,17 @@ files = [
[[package]]
name = "claude-agent-sdk"
version = "0.1.35"
version = "0.1.39"
description = "Python SDK for Claude Code"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
{file = "claude_agent_sdk-0.1.39-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ed6a79781f545b761b9fe467bc5ae213a103c9d3f0fe7a9dad3c01790ed58fa"},
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0c03b5a3772eaec42e29ea39240c7d24b760358082f2e36336db9e71dde3dda4"},
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d2665c9e87b6ffece590bcdd6eb9def47cde4809b0d2f66e0a61a719189be7c9"},
{file = "claude_agent_sdk-0.1.39-py3-none-win_amd64.whl", hash = "sha256:d03324daf7076be79d2dd05944559aabf4cc11c98d3a574b992a442a7c7a26d6"},
{file = "claude_agent_sdk-0.1.39.tar.gz", hash = "sha256:dcf0ebd5a638c9a7d9f3af7640932a9212b2705b7056e4f08bd3968a865b4268"},
]
[package.dependencies]
@@ -8530,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
content-hash = "3ef62836d8321b9a3b8e897dade8dc6ca9022fd9468c53f384b0871b521ab343"

View File

@@ -16,7 +16,7 @@ anthropic = "^0.79.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
claude-agent-sdk = "^0.1.0"
claude-agent-sdk = "^0.1.39" # see copilot/sdk/sdk_compat_test.py for capability checks
click = "^8.2.0"
cryptography = "^46.0"
discord-py = "^2.5.2"

View File

@@ -58,6 +58,7 @@ function toToolInput(rawArguments: unknown): unknown {
export function convertChatSessionMessagesToUiMessages(
sessionId: string,
rawMessages: unknown[],
options?: { isComplete?: boolean },
): UIMessage<unknown, UIDataTypes, UITools>[] {
const messages = coerceSessionChatMessages(rawMessages);
const toolOutputsByCallId = new Map<string, unknown>();
@@ -104,6 +105,16 @@ export function convertChatSessionMessagesToUiMessages(
input,
output: typeof output === "string" ? safeJsonParse(output) : output,
});
} else if (options?.isComplete) {
// Session is complete (no active stream) but this tool call has
// no output in the DB — mark as completed to stop stale spinners.
parts.push({
type: `tool-${toolName}`,
toolCallId,
state: "output-available",
input,
output: "",
});
} else {
parts.push({
type: `tool-${toolName}`,

View File

@@ -40,16 +40,6 @@ export function useChatSession() {
}
}, [sessionId, queryClient]);
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
// array reference every render. Re-derives only when query data changes.
const hydratedMessages = useMemo(() => {
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
return convertChatSessionMessagesToUiMessages(
sessionId,
sessionQuery.data.data.messages ?? [],
);
}, [sessionQuery.data, sessionId]);
// Expose active_stream info so the caller can trigger manual resume
// after hydration completes (rather than relying on AI SDK's built-in
// resume which fires before hydration).
@@ -58,6 +48,19 @@ export function useChatSession() {
return !!sessionQuery.data.data.active_stream;
}, [sessionQuery.data]);
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
// array reference every render. Re-derives only when query data changes.
// When the session is complete (no active stream), mark dangling tool
// calls as completed so stale spinners don't persist after refresh.
const hydratedMessages = useMemo(() => {
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
return convertChatSessionMessagesToUiMessages(
sessionId,
sessionQuery.data.data.messages ?? [],
{ isComplete: !hasActiveStream },
);
}, [sessionQuery.data, sessionId, hasActiveStream]);
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
usePostV2CreateSession({
mutation: {

View File

@@ -1,4 +1,5 @@
import {
getGetV2GetSessionQueryKey,
getGetV2ListSessionsQueryKey,
postV2CancelSessionTask,
useDeleteV2DeleteSession,
@@ -187,11 +188,35 @@ export function useCopilotPage() {
});
}, [hydratedMessages, setMessages, status]);
// Ref: tracks whether we've already resumed for a given session.
// Reset when the stream ends so re-resume is possible if the backend
// task is still running (SSE dropped but executor didn't finish).
const hasResumedRef = useRef<string | null>(null);
// When the stream ends (or drops), invalidate the session cache so the
// next hydration fetches fresh messages from the backend. Without this,
// staleTime: Infinity means the cache keeps the pre-stream data forever,
// and any messages added during streaming are lost on remount/navigation.
const prevStatusRef = useRef(status);
useEffect(() => {
const prev = prevStatusRef.current;
prevStatusRef.current = status;
const wasActive = prev === "streaming" || prev === "submitted";
const isIdle = status === "ready" || status === "error";
if (wasActive && isIdle && sessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
// Allow re-resume if the backend task is still running.
hasResumedRef.current = null;
}
}, [status, sessionId, queryClient]);
// Resume an active stream AFTER hydration completes.
// The backend returns active_stream info when a task is still running.
// We wait for hydration so the AI SDK has the conversation history
// before the resumed stream appends the in-progress assistant message.
const hasResumedRef = useRef<string | null>(null);
useEffect(() => {
if (!hasActiveStream || !sessionId) return;
if (!hydratedMessages || hydratedMessages.length === 0) return;
@@ -202,18 +227,6 @@ export function useCopilotPage() {
resumeStream();
}, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);
// When the stream finishes, resolve any tool parts still showing spinners.
// This can happen if the backend didn't emit StreamToolOutputAvailable for
// a tool call before sending StreamFinish (e.g. SDK built-in tools).
const prevStatusRef = useRef(status);
useEffect(() => {
const prev = prevStatusRef.current;
prevStatusRef.current = status;
if (prev === "streaming" && status === "ready") {
setMessages((msgs) => resolveInProgressTools(msgs, "completed"));
}
}, [status, setMessages]);
// Poll session endpoint when a long-running tool (create_agent, edit_agent)
// is in progress. When the backend completes, the session data will contain
// the final tool output — this hook detects the change and updates messages.