Compare commits

..

27 Commits

Author SHA1 Message Date
Zamil Majdy
9106f0b5ce poetry lock 2026-02-20 15:11:46 +07:00
Zamil Majdy
37355f7581 fix(copilot): non-cancelling heartbeat, incremental saves, frontend reconnection
- Replace asyncio.timeout() with asyncio.wait() for SDK message iteration
  to avoid corrupting the internal anyio stream on timeout (root cause of
  tool outputs getting stuck)
- Add CancelledError handling + pending task cleanup in finally block
- Fix _end_text_if_open([]) discarding StreamTextEnd events (Sentry bug)
- Save session to DB after each tool input/output for cross-device recovery
- Optimize incremental saves by passing existing_message_count to skip
  redundant DB count queries
- Frontend: invalidate session cache + reset resume ref on stream end
  so SSE reconnection works after drops
2026-02-20 15:07:18 +07:00
Zamil Majdy
e1e3b6094e poetry lock 2026-02-20 15:07:18 +07:00
Zamil Majdy
e18b3c561f Merge branch 'dev' into fix/messed-up-copilot 2026-02-20 11:51:49 +05:30
Zamil Majdy
d937c6839a fix(copilot): handle stream ending without text + PostToolUse logging
When the SDK CLI exits without sending a ResultMessage (parallel tool
execution), the frontend never gets StreamFinish and tools appear stuck.
Now detect StopAsyncIteration and emit StreamFinish as a fallback.

Also add INFO-level PostToolUse hook logging to trace whether the SDK
fires hooks and stashes output for built-in tools like WebSearch.
2026-02-20 13:12:05 +07:00
Zamil Majdy
8c2363ea88 fix(copilot): add safety-net flush and diagnostic logging for parallel tools
WebSearch/web_fetch parallel tool calls end with spinners resolving but no
output shown, then the session ends with no text response at all. Add:

- Safety-net flush after streaming loop for any unresolved tools
- INFO-level logging for every SDK message (type, unresolved count)
- UserMessage block detail logging to trace tool result delivery
- Flush-called-but-empty logging to detect already-resolved-elsewhere
2026-02-20 13:07:42 +07:00
Zamil Majdy
a408b45542 fix(copilot): don't flush parallel tool calls prematurely
The SDK sends parallel tool calls as separate AssistantMessages each
containing only ToolUseBlocks.  The flush logic treated each new
AssistantMessage as a new turn and prematurely emitted empty output for
prior tools, causing spinners to disappear and the stream to appear
stuck.

Skip flush and wait_for_stash when the incoming AssistantMessage is a
parallel continuation (contains only ToolUseBlocks).  Also prevent
duplicate StreamToolOutputAvailable for already-resolved tool calls.
2026-02-20 11:43:44 +07:00
Zamil Majdy
3a38b5e9bd fix(copilot): address review comments — wait_for_stash fast path, error marker, compat test
- Add fast path in wait_for_stash: check event.is_set() before clearing
  to avoid unnecessary 0.5s timeout when PostToolUse hook completes
  before the streaming loop calls wait_for_stash
- Tighten "failed" error marker to "failed to" in _is_tool_error_or_denial
  to avoid matching benign outputs like "3 tests failed"
- Add max_buffer_size to SDK compat test fields_we_use
2026-02-20 11:18:49 +07:00
Zamil Majdy
3491365b45 poetry lock 2026-02-20 11:15:03 +07:00
Zamil Majdy
d3299cfd7f fix(copilot): remove redundant resolveInProgressTools streaming→ready effect
The backend wait_for_stash() + _flush_unresolved_tool_calls() already
ensures all tool calls are resolved before StreamFinish. The useEffect
that called resolveInProgressTools on streaming→ready was a frontend
safety net for the same issue — no longer needed.

Keep the function itself for stop() (user cancellation).
2026-02-20 10:52:32 +07:00
Zamil Majdy
9161090944 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into fix/messed-up-copilot
Resolve conflict in service.py:
- Take dev's _build_query_message() refactor
- Restore _compress_conversation_history (dev's signature)
- Keep _is_tool_error_or_denial (tested in dev, harmless)
- Drop redundant inline query-building and approach logging
2026-02-20 10:48:20 +07:00
Zamil Majdy
372f9bff32 fix(copilot): address review comments — SDK compat test, output_len, error marker
- Add sdk_compat_test.py (17 tests) verifying the claude-agent-sdk public
  API surface we depend on, replacing the need for a tight version pin.
- Fix output_len logging: use len(str(...)) so dict outputs report
  serialized size, not key count.
- Tighten "failed" error marker to "failed to" to avoid false positives
  on benign tool output like "3 tests failed out of 10".
2026-02-20 10:30:03 +07:00
Zamil Majdy
17995596db poetry lock 2026-02-20 10:29:33 +07:00
Zamil Majdy
7acbbd0f05 poetry lock 2026-02-20 10:21:29 +07:00
Zamil Majdy
4be03fcc08 fix(copilot): remove redundant resolveInProgressTools frontend safety net
The backend already resolves all tool calls via wait_for_stash +
_flush_unresolved_tool_calls before StreamFinish, making the
streaming→ready transition cleanup unnecessary. The isComplete
hydration fix (for page refresh/crash recovery) is kept since it
covers a genuinely different failure mode.
2026-02-20 10:12:33 +07:00
Zamil Majdy
eb7bd6bdae fix(copilot): unify context-building logic for resume and non-resume paths
Consolidates the two separate context-injection paths (gap detection for
--resume, full compression for non-resume) into a single flow: determine
messages → compress → format → prepend. Renames _compress_conversation_history
to _compress_messages accepting a list directly.
2026-02-20 10:05:04 +07:00
Zamil Majdy
d81e7dd6c9 Merge branch 'dev' into fix/messed-up-copilot
Resolve service.py conflicts: take dev's file-based transcript approach
(CapturedTranscript.path + read_transcript_file) and public client API,
layer our fixes on top (wait_for_stash race-condition fix, session_id
logging, approach logging).
2026-02-20 09:54:07 +07:00
Zamil Majdy
78b52b956d fix(copilot): address PR review comments — runtime check, SDK version pin, event-based stash
- Replace bare `assert client._query` with proper RuntimeError check
- Add TECH DEBT comments on private SDK internal usage
- Pin claude-agent-sdk to ~0.1.35 (tighter constraint for private API access)
- Replace sleep(0.1) with event-based wait_for_stash() for race-condition fix
- Add wait_for_stash synchronisation tests
2026-02-20 09:46:19 +07:00
Zamil Majdy
e476185c3a fix(copilot): mitigate SDK hook race condition and improve diagnostic logging
- Add 100ms yield before flush when unresolved tool calls exist, giving
  PostToolUse hooks time to complete before the stash is checked. This
  mitigates the race condition in claude_agent_sdk where hooks are
  fire-and-forget (start_soon) while messages arrive immediately.
- Add has_unresolved_tool_calls property to SDKResponseAdapter
- Differentiate empty flush warnings to flag likely race conditions
- Add session_id to all SDK log messages ([SDK] [<session>] ...)
- Log session approach (resume/compression/single-turn) with context sizes
- Elevate session save log from debug to info
2026-02-20 09:32:30 +07:00
Zamil Majdy
b1c5000937 fix(copilot): improve tool flush logging, transcript capture, and stale spinner safety nets
- Elevate flush logging from debug to info/warning with structured messages
  showing tool names and IDs for production diagnostics
- Capture raw SDK output for transcript instead of relying on Stop hook
  file path (CLI doesn't write JSONL in SDK mode)
- Add _build_transcript() to reconstruct JSONL from captured entries
- Add isComplete option to hydration conversion — marks dangling tool calls
  as completed when session has no active stream (fixes stale spinners on
  page refresh)
- Add resolveInProgressTools safety net on streaming→ready transition
  (catches tool parts the backend didn't emit output for)
- Add 3 new tests for flush mechanism (ResultMessage, AssistantMessage,
  stashed output)
2026-02-20 08:56:56 +07:00
Zamil Majdy
7ee870ed70 fix(copilot): catch OSError in sandbox killpg to prevent zombie processes
Catch OSError broadly (not just ProcessLookupError) when calling
os.killpg so that EPERM or other errors don't skip the subsequent
await proc.communicate(), which would leave the subprocess un-reaped.
2026-02-20 02:49:51 +07:00
Zamil Majdy
240e403592 fix(copilot): fix transcript validation and resume test resilience
- Replace brittle line-count check (< 3) in read_transcript_file with
  proper validate_transcript() which checks for actual user/assistant
  entries — avoids rejecting valid short transcripts while still
  filtering metadata-only files
- Add debug logging for transcript source selection and fallback path
  to aid diagnosing resume issues in Docker
- Make test_sdk_resume_multi_turn skip gracefully when the CLI doesn't
  produce usable transcripts (environment-dependent: CLI version,
  platform) instead of hard-failing
2026-02-20 02:47:23 +07:00
Zamil Majdy
c3e94f7d9c fix(copilot): address review comments — counter order, error markers, tests
- Move task_spawn_count increment after limit check so denied spawns
  don't consume a slot (greptile feedback)
- Add "failed" marker to _is_tool_error_or_denial to catch internal
  tool execution failures from _mcp_error (coderabbit feedback)
- Add 17 unit tests for _is_tool_error_or_denial covering all markers,
  denial messages, and false-positive scenarios
2026-02-20 02:29:51 +07:00
Zamil Majdy
a0a040f102 fix(copilot): sandbox kill, tool event logging, and background task UX
- Fix sandbox process kill: use start_new_session + os.killpg to kill
  the entire bwrap process group on timeout (proc.kill alone only kills
  the parent, leaving children running until natural completion)
- Add StreamToolInputAvailable/StreamToolOutputAvailable to publish_chunk
  logging filter so tool events are visible in Docker logs
- Add system prompt instruction telling Claude not to use
  run_in_background on Task tool (gets denied by security hooks)
- Add tool event debug logging in SDK streaming loop for tracing
  tool execution visibility issues
2026-02-20 02:24:31 +07:00
Zamil Majdy
23225aa323 fix(copilot): address review comments — slot counting, heartbeat, error detection
- Move run_in_background check before task_spawn_count increment so
  denied background Tasks don't consume a subtask slot
- Replace asyncio.wait_for() with asyncio.timeout() for heartbeat loop
  to avoid leaving the async generator in a broken state
- Tighten _is_tool_error_or_denial: remove overly broad markers
  ("error", "failed", "not found") that cause false positives; add
  markers for actual denial messages ("not supported", "maximum")
2026-02-20 01:19:33 +07:00
Zamil Majdy
fed645cb79 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into fix/messed-up-copilot 2026-02-20 01:14:46 +07:00
Zamil Majdy
009753f2b3 fix(copilot): prevent background agent stalls and context hallucination
- Block Task tool's run_in_background param in security hooks — background
  agents stall the SSE stream and get killed when the main agent exits
- Add heartbeats (15s interval) to SDK streaming loop so proxies/LBs don't
  close idle SSE connections during long tool execution
- Fix summarization prompt that forced LLM to fabricate content for all 9
  mandatory sections; now sections are optional and hallucination is
  explicitly prohibited
- Include tool error/denial outcomes in conversation context formatting —
  previously all tool messages were dropped, so the agent couldn't see
  that security denials blocked its file writes and hallucinated success
2026-02-19 23:28:22 +07:00
10 changed files with 116 additions and 279 deletions

View File

@@ -27,6 +27,7 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
@@ -38,10 +39,8 @@ class ChatConfig(BaseSettings):
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=3600,
description="TTL in seconds for long-running operation deduplication lock "
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
"For longer operations, the stream_registry heartbeat keeps them alive.",
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection

View File

@@ -132,97 +132,58 @@ async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> tuple[list[ChatMessage], int]:
) -> list[ChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses collision detection with retry: tries to create messages starting
at start_sequence. If a unique constraint violation occurs (e.g., the
streaming loop and long-running callback race), queries MAX(sequence)
and retries with the correct next sequence number. This avoids
unnecessary upserts and DB queries in the common case (no collision).
Returns:
Tuple of (messages, final_message_count) where final_message_count
is the total number of messages in the session after insertion.
This allows callers to update their counters even when collision
detection adjusts start_sequence.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
"""
if not messages:
# No messages to add - return current count
return [], start_sequence
return []
max_retries = 3
for attempt in range(max_retries):
try:
created_messages = []
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
created_messages = []
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
# Return messages and final message count (for shared counter sync)
final_count = start_sequence + len(messages)
return [ChatMessage.from_db(m) for m in created_messages], final_count
except Exception as e:
# Check if it's a unique constraint violation
error_msg = str(e).lower()
is_unique_constraint = (
"unique constraint" in error_msg or "duplicate key" in error_msg
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
if is_unique_constraint and attempt < max_retries - 1:
# Collision detected - query MAX(sequence)+1 and retry with correct offset
logger.info(
f"Collision detected for session {session_id} at sequence "
f"{start_sequence}, querying DB for latest sequence"
)
start_sequence = await get_next_sequence(session_id)
logger.info(
f"Retrying batch insert with start_sequence={start_sequence}"
)
continue
else:
# Not a collision or max retries exceeded - propagate error
raise
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
# Should never reach here due to raise in exception handler
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
return [ChatMessage.from_db(m) for m in created_messages]
async def get_user_chat_sessions(
@@ -276,23 +237,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return False
async def get_next_sequence(session_id: str) -> int:
"""Get the next sequence number for a new message in this session.
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
More robust than COUNT(*) because it's immune to deleted messages.
"""
result = await db.prisma.query_raw(
"""
SELECT COALESCE(MAX(sequence) + 1, 0) as next_seq
FROM "ChatMessage"
WHERE "sessionId" = $1
""",
session_id,
)
if not result or len(result) == 0:
return 0
return int(result[0]["next_seq"])
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(

View File

@@ -266,11 +266,7 @@ class CoPilotProcessor:
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_task_completed(
entry.task_id,
status="failed",
error_message="Task was cancelled",
)
await stream_registry.mark_task_completed(entry.task_id, status="failed")
raise
except Exception as e:

View File

@@ -436,7 +436,7 @@ async def upsert_chat_session(
session: ChatSession,
*,
existing_message_count: int | None = None,
) -> tuple[ChatSession, int]:
) -> ChatSession:
"""Update a chat session in both cache and database.
Uses session-level locking to prevent race conditions when concurrent
@@ -449,10 +449,6 @@ async def upsert_chat_session(
accurately. Useful for incremental saves in a streaming loop
where the caller already knows how many messages are persisted.
Returns:
Tuple of (session, final_message_count) where final_message_count is
the actual persisted message count after collision detection adjustments.
Raises:
DatabaseError: If the database write fails. The cache is still updated
as a best-effort optimization, but the error is propagated to ensure
@@ -465,16 +461,15 @@ async def upsert_chat_session(
async with lock:
# Get existing message count from DB for incremental saves
if existing_message_count is None:
existing_message_count = await chat_db().get_next_sequence(
existing_message_count = await chat_db().get_chat_session_message_count(
session.session_id
)
db_error: Exception | None = None
final_count = existing_message_count
# Save to database (primary storage)
try:
final_count = await _save_session_to_db(
await _save_session_to_db(
session,
existing_message_count,
skip_existence_check=existing_message_count > 0,
@@ -505,7 +500,7 @@ async def upsert_chat_session(
f"Failed to persist chat session {session.session_id} to database"
) from db_error
return session, final_count
return session
async def _save_session_to_db(
@@ -513,16 +508,13 @@ async def _save_session_to_db(
existing_message_count: int,
*,
skip_existence_check: bool = False,
) -> int:
) -> None:
"""Save or update a chat session in the database.
Args:
skip_existence_check: When True, skip the ``get_chat_session`` query
and assume the session row already exists. Saves one DB round trip
for incremental saves during streaming.
Returns:
Final message count after save (accounting for collision detection).
"""
db = chat_db()
@@ -554,7 +546,6 @@ async def _save_session_to_db(
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
final_count = existing_message_count
if new_messages:
messages_data = []
for msg in new_messages:
@@ -574,14 +565,12 @@ async def _save_session_to_db(
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
_, final_count = await db.add_chat_messages_batch(
await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
return final_count
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
@@ -598,7 +587,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
existing_message_count = await chat_db().get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)

View File

@@ -60,7 +60,7 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s, _ = await upsert_chat_session(s)
s = await upsert_chat_session(s)
s2 = await get_chat_session(
session_id=s.session_id,
@@ -77,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s, _ = await upsert_chat_session(s)
s = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
@@ -94,7 +94,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s, _ = await upsert_chat_session(s)
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"

View File

@@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from backend.data.redis_client import get_redis_async
from backend.util.exceptions import NotFoundError
from .. import stream_registry
@@ -133,65 +132,8 @@ is delivered to the user via a background stream.
All tasks must run in the foreground.
"""
# Session streaming lock configuration
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
STREAM_LOCK_TTL = 3600 # 1 hour - matches stream_ttl
async def _acquire_stream_lock(session_id: str, stream_id: str) -> bool:
"""Acquire an exclusive lock for streaming to this session.
Prevents multiple concurrent streams to the same session which can cause:
- Message duplication
- Race conditions in message saves
- Confusing UX with multiple AI responses
Returns:
True if lock was acquired, False if another stream is active.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# SET NX EX - atomic "set if not exists" with expiry
result = await redis.set(lock_key, stream_id, ex=STREAM_LOCK_TTL, nx=True)
return result is not None
async def _release_stream_lock(session_id: str, stream_id: str) -> None:
"""Release the stream lock if we still own it.
Only releases the lock if the stored stream_id matches ours (prevents
releasing another stream's lock if we somehow timed out).
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# Lua script for atomic compare-and-delete (only delete if value matches)
script = """
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end
"""
await redis.eval(script, 1, lock_key, stream_id) # type: ignore[misc]
async def check_active_stream(session_id: str) -> str | None:
"""Check if a stream is currently active for this session.
Returns:
The active stream_id if one exists, None otherwise.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
active_stream = await redis.get(lock_key)
return active_stream.decode() if isinstance(active_stream, bytes) else active_stream
def _build_long_running_callback(
user_id: str | None,
saved_msg_count_ref: list[int] | None = None,
) -> LongRunningCallback:
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
@@ -200,12 +142,6 @@ def _build_long_running_callback(
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
Args:
user_id: User ID for the session
saved_msg_count_ref: Mutable reference [count] shared with streaming loop
for coordinating message saves. When provided, the callback will update
it after appending messages to prevent counter drift.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
@@ -271,11 +207,7 @@ def _build_long_running_callback(
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
# Collision detection happens in add_chat_messages_batch (db.py)
_, final_count = await upsert_chat_session(session)
# Update shared counter so streaming loop stays in sync
if saved_msg_count_ref is not None:
saved_msg_count_ref[0] = final_count
await upsert_chat_session(session)
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
@@ -610,7 +542,7 @@ async def stream_chat_completion_sdk(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session, _ = await upsert_chat_session(session)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
@@ -632,23 +564,6 @@ async def stream_chat_completion_sdk(
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
stream_id = task_id # Use task_id as unique stream identifier
# Acquire stream lock to prevent concurrent streams to the same session
lock_acquired = await _acquire_stream_lock(session_id, stream_id)
if not lock_acquired:
# Another stream is active - check if it's still alive
active_stream = await check_active_stream(session_id)
logger.warning(
f"[SDK] Session {session_id} already has an active stream: {active_stream}"
)
yield StreamError(
errorText="Another stream is already active for this session. "
"Please wait for it to complete or refresh the page.",
code="stream_already_active",
)
yield StreamFinish()
return
yield StreamStart(messageId=message_id, taskId=task_id)
@@ -666,16 +581,10 @@ async def stream_chat_completion_sdk(
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
# Initialize saved message counter as mutable list so long-running
# callback and streaming loop can coordinate
saved_msg_count_ref: list[int] = [len(session.messages)]
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(
user_id, saved_msg_count_ref
),
long_running_callback=_build_long_running_callback(user_id),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
@@ -806,8 +715,9 @@ async def stream_chat_completion_sdk(
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
# Track persisted message count. Uses shared ref so long-running
# callback can update it for coordination
# Track persisted message count to skip DB count queries
# on incremental saves. Initial save happened at line 545.
saved_msg_count = len(session.messages)
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
@@ -983,12 +893,13 @@ async def stream_chat_completion_sdk(
has_appended_assistant = True
# Save before tool execution starts so the
# pending tool call is visible on refresh /
# other devices. Collision detection happens
# in add_chat_messages_batch (db.py).
# other devices.
try:
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
await upsert_chat_session(
session,
existing_message_count=saved_msg_count,
)
saved_msg_count = len(session.messages)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
@@ -1011,11 +922,12 @@ async def stream_chat_completion_sdk(
has_tool_results = True
# Save after tool completes so the result is
# visible on refresh / other devices.
# Collision detection happens in add_chat_messages_batch (db.py).
try:
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
await upsert_chat_session(
session,
existing_message_count=saved_msg_count,
)
saved_msg_count = len(session.messages)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
@@ -1147,12 +1059,11 @@ async def stream_chat_completion_sdk(
"to use the OpenAI-compatible fallback."
)
_, final_count = await asyncio.shield(upsert_chat_session(session))
await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session saved with %d messages (DB count: %d)",
"[SDK] [%s] Session saved with %d messages",
session_id[:12],
len(session.messages),
final_count,
)
if not stream_completed:
yield StreamFinish()
@@ -1210,9 +1121,6 @@ async def stream_chat_completion_sdk(
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
# Release stream lock to allow new streams for this session
await _release_stream_lock(session_id, stream_id)
async def _try_upload_transcript(
user_id: str,

View File

@@ -352,8 +352,7 @@ async def assign_user_to_session(
if not session:
raise NotFoundError(f"Session {session_id} not found")
session.user_id = user_id
session, _ = await upsert_chat_session(session)
return session
return await upsert_chat_session(session)
async def stream_chat_completion(
@@ -464,7 +463,7 @@ async def stream_chat_completion(
)
upsert_start = time.monotonic()
session, _ = await upsert_chat_session(session)
session = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
@@ -690,7 +689,7 @@ async def stream_chat_completion(
f"tool_responses={len(tool_response_messages)}"
)
if messages_to_save_early or has_appended_streaming_message:
_ = await upsert_chat_session(session)
await upsert_chat_session(session)
has_saved_assistant_message = True
has_yielded_end = True
@@ -729,7 +728,7 @@ async def stream_chat_completion(
if tool_response_messages:
session.messages.extend(tool_response_messages)
try:
_ = await upsert_chat_session(session)
await upsert_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to save interrupted session {session.session_id}: {e}"
@@ -770,7 +769,7 @@ async def stream_chat_completion(
if messages_to_save:
session.messages.extend(messages_to_save)
if messages_to_save or has_appended_streaming_message:
_ = await upsert_chat_session(session)
await upsert_chat_session(session)
if not has_yielded_error:
error_message = str(e)
@@ -854,7 +853,7 @@ async def stream_chat_completion(
not has_long_running_tool_call
and (messages_to_save or has_appended_streaming_message)
):
_ = await upsert_chat_session(session)
await upsert_chat_session(session)
else:
logger.info(
"Assistant message already saved when StreamFinish was received, "
@@ -1526,7 +1525,7 @@ async def _yield_tool_call(
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
_ = await upsert_chat_session(session)
await upsert_chat_session(session)
await _with_optional_lock(session_lock, _save_pending)
logger.info(
@@ -1564,11 +1563,7 @@ async def _yield_tool_call(
await _mark_operation_completed(tool_call_id)
# Mark stream registry task as failed if it was created
try:
await stream_registry.mark_task_completed(
task_id,
status="failed",
error_message=f"Failed to setup tool {tool_name}: {e}",
)
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception as mark_err:
logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
logger.error(
@@ -1736,11 +1731,7 @@ async def _execute_long_running_tool_with_streaming(
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for background tool")
await stream_registry.mark_task_completed(
task_id,
status="failed",
error_message=f"Session {session_id} not found",
)
await stream_registry.mark_task_completed(task_id, status="failed")
return
# Pass operation_id and task_id to the tool for async processing
@@ -2020,7 +2011,7 @@ async def _generate_llm_continuation(
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
_ = await upsert_chat_session(fresh_session)
await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)
@@ -2226,7 +2217,7 @@ async def _generate_llm_continuation_with_streaming(
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
_ = await upsert_chat_session(fresh_session)
await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)

View File

@@ -58,7 +58,7 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session, _ = await upsert_chat_session(session)
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
@@ -104,7 +104,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session, _ = await upsert_chat_session(session)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "ZEPHYR42"

View File

@@ -644,8 +644,6 @@ async def _stream_listener(
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
*,
error_message: str | None = None,
) -> bool:
"""Mark a task as completed and publish finish event.
@@ -656,10 +654,6 @@ async def mark_task_completed(
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
error_message: If provided and status="failed", publish a StreamError
before StreamFinish so connected clients see why the task ended.
If not provided, no StreamError is published (caller should publish
manually if needed to avoid duplicates).
Returns:
True if task was newly marked completed, False if already completed/failed
@@ -675,17 +669,6 @@ async def mark_task_completed(
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
# Publish error event before finish so connected clients know WHY the
# task ended. Only publish if caller provided an explicit error message
# to avoid duplicates with code paths that manually publish StreamError.
# This is best-effort — if it fails, the StreamFinish still ensures
# listeners clean up.
if status == "failed" and error_message:
try:
await publish_chunk(task_id, StreamError(errorText=error_message))
except Exception as e:
logger.warning(f"Failed to publish error event for task {task_id}: {e}")
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
@@ -838,6 +821,27 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError) as exc:
logger.warning(
f"[TASK_LOOKUP] Failed to parse created_at "
f"for task {task_id[:8]}...: {exc}"
)
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)

View File

@@ -303,7 +303,7 @@ class DatabaseManager(AppService):
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
get_user_session_count = _(chat_db.get_user_session_count)
delete_chat_session = _(chat_db.delete_chat_session)
get_next_sequence = _(chat_db.get_next_sequence)
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
update_tool_message_content = _(chat_db.update_tool_message_content)
@@ -473,5 +473,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_chat_sessions = d.get_user_chat_sessions
get_user_session_count = d.get_user_session_count
delete_chat_session = d.delete_chat_session
get_next_sequence = d.get_next_sequence
get_chat_session_message_count = d.get_chat_session_message_count
update_tool_message_content = d.update_tool_message_content