mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(copilot): remove legacy copilot, add baseline non-SDK mode with tool calling (#12276)
## Summary - Remove ~1200 lines of broken/unmaintained non-SDK copilot streaming code (retry logic, parallel tool calls, context window management) - Add `stream_chat_completion_baseline()` as a clean fallback LLM path with full tool-calling support when `CHAT_USE_CLAUDE_AGENT_SDK=false` (e.g. when Anthropic is down) - Baseline reuses the same shared `TOOL_REGISTRY`, `get_available_tools()`, and `execute_tool()` as the SDK path - Move baseline code to dedicated `baseline/` folder (mirrors `sdk/` structure) - Clean up SDK service: remove unused params, fix model/env resolution, fix stream error persistence - Clean up config: remove `max_retries`, `thinking_enabled` fields (non-SDK only) ## Changes | File | Action | |------|--------| | `backend/copilot/baseline/__init__.py` | New — package export | | `backend/copilot/baseline/service.py` | New — baseline streaming with tool-call loop | | `backend/copilot/baseline/service_test.py` | New — multi-turn keyword recall test | | `backend/copilot/service.py` | Remove ~1200 lines of legacy code, keep shared helpers only | | `backend/copilot/executor/processor.py` | Simplify branching to SDK vs baseline | | `backend/copilot/sdk/service.py` | Remove unused params, fix model/env separation, fix stream error persistence | | `backend/copilot/config.py` | Remove `max_retries`, `thinking_enabled` | | `backend/copilot/service_test.py` | Keep SDK test only (baseline test moved) | | `backend/copilot/parallel_tool_calls_test.py` | Deleted (tested removed code) | ## Test plan - [x] `poetry run format` passes - [x] CI passes (all 3 Python versions, types, CodeQL) - [ ] SDK path works unchanged in production - [x] Baseline path (`CHAT_USE_CLAUDE_AGENT_SDK=false`) streams responses with tool calling - [x] Baseline emits correct Vercel AI SDK stream protocol events
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .service import stream_chat_completion_baseline
|
||||
|
||||
__all__ = ["stream_chat_completion_baseline"]
|
||||
398
autogpt_platform/backend/backend/copilot/baseline/service.py
Normal file
398
autogpt_platform/backend/backend/copilot/baseline/service.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
|
||||
|
||||
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
|
||||
Claude Agent SDK / Anthropic API is unavailable. Routes through any
|
||||
OpenAI-compatible provider (OpenRouter by default) and reuses the same
|
||||
shared tool registry as the SDK path.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
# Maximum number of tool-call rounds before forcing a text response.
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None
|
||||
) -> None:
|
||||
"""Generate and persist a session title in the background."""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Failed to update session title: %s", e)
|
||||
|
||||
|
||||
async def _compress_session_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress session messages if they exceed the model's token limit.
|
||||
|
||||
Uses the shared compress_context() utility which supports LLM-based
|
||||
summarization of older messages while keeping recent ones intact,
|
||||
with progressive truncation and middle-out deletion as fallbacks.
|
||||
"""
|
||||
messages_dict = []
|
||||
for msg in messages:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
"[Baseline] Context compacted: %d -> %d tokens "
|
||||
"(%d summarized, %d dropped)",
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
return [
|
||||
ChatMessage(role=m["role"], content=m.get("content"))
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def stream_chat_completion_baseline(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Baseline LLM with tool calling via OpenAI-compatible API.
|
||||
|
||||
Designed as a fallback when the Claude Agent SDK is unavailable.
|
||||
Uses the same tool registry as the SDK path but routes through any
|
||||
OpenAI-compatible provider (e.g. OpenRouter).
|
||||
|
||||
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
|
||||
"""
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
# Append user message
|
||||
new_role = "user" if is_user_message else "assistant"
|
||||
if message and (
|
||||
len(session.messages) == 0
|
||||
or not (
|
||||
session.messages[-1].role == new_role
|
||||
and session.messages[-1].content == message
|
||||
)
|
||||
):
|
||||
session.messages.append(ChatMessage(role=new_role, content=message))
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(message),
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
|
||||
# Build OpenAI message list from session history
|
||||
openai_messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
for msg in messages_for_context:
|
||||
if msg.role in ("user", "assistant") and msg.content:
|
||||
openai_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
tools = get_available_tools()
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
yield StreamStartStep()
|
||||
step_open = True
|
||||
|
||||
# Stream a response from the model
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Text content
|
||||
if delta.content:
|
||||
if not text_started:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
round_text += delta.content
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.content)
|
||||
|
||||
# Tool call fragments (streamed incrementally)
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block if we had one this round
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_started = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Accumulate text for session persistence
|
||||
assistant_text += round_text
|
||||
|
||||
# No tool calls -> model is done
|
||||
if not tool_calls_by_index:
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
break
|
||||
|
||||
# Close step before tool execution
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"],
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
openai_messages.append(assistant_msg)
|
||||
|
||||
# Execute each tool call and stream events
|
||||
for tc in tool_calls_by_index.values():
|
||||
tool_call_id = tc["id"]
|
||||
tool_name = tc["name"]
|
||||
raw_args = tc["arguments"] or "{}"
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = (
|
||||
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
)
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": parse_error,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
|
||||
# Execute via shared tool registry
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
yield result
|
||||
tool_output = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else str(result.output)
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
tool_output = error_output
|
||||
|
||||
# Append tool result to context for next round
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_output,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# for-loop exhausted without break -> tool-round limit hit
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
"without a final response."
|
||||
)
|
||||
logger.error("[Baseline] %s", limit_msg)
|
||||
yield StreamError(
|
||||
errorText=limit_msg,
|
||||
code="baseline_tool_round_limit",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
if step_open:
|
||||
yield StreamFinishStep()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=assistant_text)
|
||||
)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
yield StreamFinish()
|
||||
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.model import (
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_baseline_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the baseline LLM path streams responses and maintains history.
|
||||
|
||||
Turn 1: Send a message with a unique keyword.
|
||||
Turn 2: Ask the model to recall the keyword — proving conversation history
|
||||
is correctly passed to the single-call LLM.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
keyword = "QUASAR99"
|
||||
turn1_msg = (
|
||||
f"Please remember this special keyword: {keyword}. "
|
||||
"Just confirm you've noted it, keep your response brief."
|
||||
)
|
||||
turn1_text = ""
|
||||
turn1_errors: list[str] = []
|
||||
got_start = False
|
||||
got_finish = False
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn1_msg,
|
||||
user_id=test_user_id,
|
||||
):
|
||||
if isinstance(chunk, StreamStart):
|
||||
got_start = True
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
turn1_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn1_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
got_finish = True
|
||||
|
||||
assert got_start, "Turn 1 did not yield StreamStart"
|
||||
assert got_finish, "Turn 1 did not yield StreamFinish"
|
||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
||||
assert turn1_text, "Turn 1 produced no text"
|
||||
logger.info(f"Turn 1 response: {turn1_text[:100]}")
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert session, "Session not found after turn 1"
|
||||
|
||||
# Verify messages were persisted (user + assistant)
|
||||
assert (
|
||||
len(session.messages) >= 2
|
||||
), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"
|
||||
|
||||
# --- Turn 2: ask model to recall the keyword ---
|
||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
||||
turn2_text = ""
|
||||
turn2_errors: list[str] = []
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn2_msg,
|
||||
user_id=test_user_id,
|
||||
session=session,
|
||||
):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
turn2_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn2_errors.append(chunk.errorText)
|
||||
|
||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
||||
assert turn2_text, "Turn 2 produced no text"
|
||||
assert keyword in turn2_text, (
|
||||
f"Model did not recall keyword '{keyword}' in turn 2. "
|
||||
f"Response: {turn2_text[:200]}"
|
||||
)
|
||||
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
||||
@@ -26,11 +26,6 @@ class ChatConfig(BaseSettings):
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
@@ -71,7 +66,7 @@ class ChatConfig(BaseSettings):
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
|
||||
)
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
@@ -113,12 +108,6 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox keepalive timeout in seconds.",
|
||||
)
|
||||
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable adaptive thinking for Claude models via OpenRouter",
|
||||
)
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
|
||||
@@ -9,8 +9,8 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
@@ -194,7 +194,7 @@ class CoPilotProcessor:
|
||||
):
|
||||
"""Async execution logic for a CoPilot turn.
|
||||
|
||||
Calls the stream_chat_completion service function and publishes
|
||||
Calls the chat completion service (SDK or baseline) and publishes
|
||||
results to the stream registry.
|
||||
|
||||
Args:
|
||||
@@ -218,9 +218,9 @@ class CoPilotProcessor:
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else copilot_service.stream_chat_completion
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
@@ -228,7 +228,6 @@ class CoPilotProcessor:
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
|
||||
@@ -1,269 +0,0 @@
|
||||
"""Tests for parallel tool call execution in CoPilot.
|
||||
|
||||
These tests mock _yield_tool_call to avoid importing the full copilot stack
|
||||
which requires Prisma, DB connections, etc.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_run_concurrently():
|
||||
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
n_tools = 3
|
||||
delay_per_tool = 0.2
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"tool_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(n_tools)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
original_yield = None
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
input={},
|
||||
)
|
||||
await asyncio.sleep(delay_per_tool)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
output="{}",
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
original_yield = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
start = time.monotonic()
|
||||
events = []
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
elapsed = time.monotonic() - start
|
||||
finally:
|
||||
svc._yield_tool_call = original_yield
|
||||
|
||||
assert len(events) == n_tools * 2
|
||||
# Parallel: should take ~delay, not ~n*delay
|
||||
assert elapsed < delay_per_tool * (
|
||||
n_tools - 0.5
|
||||
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call_works():
|
||||
"""Single tool call should work identically."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": "call_0",
|
||||
"type": "function",
|
||||
"function": {"name": "t", "arguments": "{}"},
|
||||
}
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
|
||||
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = [
|
||||
e
|
||||
async for e in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
)
|
||||
]
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(events) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_propagates():
|
||||
"""Retryable errors should be raised after all tools finish."""
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
if idx == 1:
|
||||
raise KeyError("bad")
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = []
|
||||
with pytest.raises(KeyError):
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
# First tool's events should still be yielded
|
||||
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_shared_across_parallel_tools():
|
||||
"""All parallel tools should receive the same session instance."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
observed_sessions = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
observed_sessions.append(sess)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
async for _ in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
pass
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(observed_sessions) == 3
|
||||
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_cleans_up():
|
||||
"""Generator close should cancel in-flight tasks."""
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
started.set()
|
||||
await asyncio.sleep(10) # simulate long-running
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
|
||||
await gen.__anext__() # get first event
|
||||
await started.wait()
|
||||
await gen.aclose() # close generator
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
# If we get here without hanging, cleanup worked
|
||||
@@ -571,17 +571,12 @@ async def _build_query_message(
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None, # noqa: ARG001
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None, # noqa: ARG001
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
|
||||
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||
"""
|
||||
"""Stream chat completion using Claude Agent SDK."""
|
||||
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -728,495 +723,491 @@ async def stream_chat_completion_sdk(
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
set_execution_context(user_id, session, sandbox=e2b_sandbox, sdk_cwd=sdk_cwd)
|
||||
try:
|
||||
# Fail fast when no API credentials are available at all
|
||||
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY "
|
||||
"(or CHAT_API_KEY) for OpenRouter routing, "
|
||||
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
||||
)
|
||||
|
||||
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
# --- Transcript capture via Stop hook ---
|
||||
# Read the file content immediately — the SDK may clean up
|
||||
# the file before our finally block runs.
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
content = read_transcript_file(transcript_path)
|
||||
if content:
|
||||
captured_transcript.raw_content = content
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: captured {len(content)}B from "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Stop hook: transcript file empty/missing at "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
|
||||
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
|
||||
compaction = CompactionTracker()
|
||||
|
||||
security_hooks = create_security_hooks(
|
||||
user_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
max_subtasks=config.claude_agent_max_subtasks,
|
||||
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
||||
on_compact=compaction.on_compact,
|
||||
# Fail fast when no API credentials are available at all
|
||||
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY "
|
||||
"(or CHAT_API_KEY) for OpenRouter routing, "
|
||||
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
||||
)
|
||||
|
||||
# --- Resume strategy: download transcript from bucket ---
|
||||
transcript_msg_count = 0 # watermark: session.messages length at upload
|
||||
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
|
||||
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
dl = await download_transcript(user_id, session_id)
|
||||
is_valid = bool(dl and validate_transcript(dl.content))
|
||||
if dl and is_valid:
|
||||
logger.info(
|
||||
f"[SDK] Transcript available for session {session_id}: "
|
||||
f"{len(dl.content)}B, msg_count={dl.message_count}"
|
||||
)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
elif dl:
|
||||
logger.warning(
|
||||
f"[SDK] Transcript downloaded but invalid for {session_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] No transcript available for {session_id} "
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
"mcp_servers": {"copilot": mcp_server},
|
||||
"allowed_tools": allowed,
|
||||
"disallowed_tools": disallowed,
|
||||
"hooks": security_hooks,
|
||||
"cwd": sdk_cwd,
|
||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||
}
|
||||
if sdk_env:
|
||||
sdk_options_kwargs["model"] = sdk_model
|
||||
sdk_options_kwargs["env"] = sdk_env
|
||||
if use_resume and resume_file:
|
||||
sdk_options_kwargs["resume"] = resume_file
|
||||
|
||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
|
||||
|
||||
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
|
||||
|
||||
# Propagate user_id/session_id as OTEL context attributes so the
|
||||
# langsmith tracing integration attaches them to every span. This
|
||||
# is what Langfuse (or any OTEL backend) maps to its native
|
||||
# user/session fields.
|
||||
_otel_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-sdk",
|
||||
tags=["sdk"],
|
||||
metadata={"resume": str(use_resume)},
|
||||
)
|
||||
_otel_ctx.__enter__()
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
current_message = message or ""
|
||||
if not current_message and session.messages:
|
||||
last_user = [m for m in session.messages if m.role == "user"]
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
if not current_message.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
code="empty_prompt",
|
||||
)
|
||||
return
|
||||
|
||||
query_message, was_compacted = await _build_query_message(
|
||||
current_message,
|
||||
session,
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
)
|
||||
# --- Transcript capture via Stop hook ---
|
||||
# Read the file content immediately — the SDK may clean up
|
||||
# the file before our finally block runs.
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
content = read_transcript_file(transcript_path)
|
||||
if content:
|
||||
captured_transcript.raw_content = content
|
||||
logger.info(
|
||||
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
|
||||
session_id[:12],
|
||||
use_resume,
|
||||
len(session.messages),
|
||||
len(query_message),
|
||||
f"[SDK] Stop hook: captured {len(content)}B from "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Stop hook: transcript file empty/missing at "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
|
||||
compaction.reset_for_query()
|
||||
if was_compacted:
|
||||
for ev in compaction.emit_pre_query(session):
|
||||
yield ev
|
||||
await client.query(query_message, session_id=session_id)
|
||||
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
|
||||
compaction = CompactionTracker()
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
security_hooks = create_security_hooks(
|
||||
user_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
max_subtasks=config.claude_agent_max_subtasks,
|
||||
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
||||
on_compact=compaction.on_compact,
|
||||
)
|
||||
|
||||
# Use an explicit async iterator with non-cancelling heartbeats.
|
||||
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
||||
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
|
||||
# anyio memory stream, causing StopAsyncIteration on the next
|
||||
# call and silently dropping all in-flight tool results.
|
||||
# Instead, wrap __anext__() in a Task and use asyncio.wait()
|
||||
# with a timeout. On timeout we emit a heartbeat but keep the
|
||||
# Task alive so it can deliver the next message.
|
||||
msg_iter = client.receive_response().__aiter__()
|
||||
pending_task: asyncio.Task[Any] | None = None
|
||||
try:
|
||||
while not stream_completed:
|
||||
if pending_task is None:
|
||||
# --- Resume strategy: download transcript from bucket ---
|
||||
transcript_msg_count = 0 # watermark: session.messages length at upload
|
||||
|
||||
async def _next_msg() -> Any:
|
||||
return await msg_iter.__anext__()
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
dl = await download_transcript(user_id, session_id)
|
||||
is_valid = bool(dl and validate_transcript(dl.content))
|
||||
if dl and is_valid:
|
||||
logger.info(
|
||||
f"[SDK] Transcript available for session {session_id}: "
|
||||
f"{len(dl.content)}B, msg_count={dl.message_count}"
|
||||
)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
elif dl:
|
||||
logger.warning(
|
||||
f"[SDK] Transcript downloaded but invalid for {session_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] No transcript available for {session_id} "
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
|
||||
pending_task = asyncio.create_task(_next_msg())
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
"mcp_servers": {"copilot": mcp_server},
|
||||
"allowed_tools": allowed,
|
||||
"disallowed_tools": disallowed,
|
||||
"hooks": security_hooks,
|
||||
"cwd": sdk_cwd,
|
||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||
}
|
||||
if sdk_model:
|
||||
sdk_options_kwargs["model"] = sdk_model
|
||||
if sdk_env:
|
||||
sdk_options_kwargs["env"] = sdk_env
|
||||
if use_resume and resume_file:
|
||||
sdk_options_kwargs["resume"] = resume_file
|
||||
|
||||
done, _ = await asyncio.wait(
|
||||
{pending_task}, timeout=_HEARTBEAT_INTERVAL
|
||||
)
|
||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
if not done:
|
||||
await lock.refresh()
|
||||
for ev in compaction.emit_start_if_ready():
|
||||
yield ev
|
||||
yield StreamHeartbeat()
|
||||
continue
|
||||
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
|
||||
|
||||
# Task completed — get result
|
||||
pending_task = None
|
||||
try:
|
||||
sdk_msg = done.pop().result()
|
||||
except StopAsyncIteration:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
|
||||
session_id[:12],
|
||||
)
|
||||
break
|
||||
except Exception as stream_err:
|
||||
# SDK sends {"type": "error"} which raises
|
||||
# Exception in receive_response() — capture it
|
||||
# so the session can still be saved and the
|
||||
# frontend gets a clean finish.
|
||||
logger.error(
|
||||
"[SDK] [%s] Stream error from SDK: %s",
|
||||
session_id[:12],
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamError(
|
||||
errorText=f"SDK stream error: {stream_err}",
|
||||
code="sdk_stream_error",
|
||||
)
|
||||
break
|
||||
# Propagate user_id/session_id as OTEL context attributes so the
|
||||
# langsmith tracing integration attaches them to every span. This
|
||||
# is what Langfuse (or any OTEL backend) maps to its native
|
||||
# user/session fields.
|
||||
_otel_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-sdk",
|
||||
tags=["sdk"],
|
||||
metadata={"resume": str(use_resume)},
|
||||
)
|
||||
_otel_ctx.__enter__()
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
current_message = message or ""
|
||||
if not current_message and session.messages:
|
||||
last_user = [m for m in session.messages if m.role == "user"]
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
if not current_message.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
code="empty_prompt",
|
||||
)
|
||||
return
|
||||
|
||||
query_message, was_compacted = await _build_query_message(
|
||||
current_message,
|
||||
session,
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
|
||||
session_id[:12],
|
||||
use_resume,
|
||||
len(session.messages),
|
||||
len(query_message),
|
||||
)
|
||||
|
||||
compaction.reset_for_query()
|
||||
if was_compacted:
|
||||
for ev in compaction.emit_pre_query(session):
|
||||
yield ev
|
||||
await client.query(query_message, session_id=session_id)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
ended_with_stream_error = False
|
||||
|
||||
# Use an explicit async iterator with non-cancelling heartbeats.
|
||||
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
||||
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
|
||||
# anyio memory stream, causing StopAsyncIteration on the next
|
||||
# call and silently dropping all in-flight tool results.
|
||||
# Instead, wrap __anext__() in a Task and use asyncio.wait()
|
||||
# with a timeout. On timeout we emit a heartbeat but keep the
|
||||
# Task alive so it can deliver the next message.
|
||||
msg_iter = client.receive_response().__aiter__()
|
||||
pending_task: asyncio.Task[Any] | None = None
|
||||
try:
|
||||
while not stream_completed:
|
||||
if pending_task is None:
|
||||
|
||||
async def _next_msg() -> Any:
|
||||
return await msg_iter.__anext__()
|
||||
|
||||
pending_task = asyncio.create_task(_next_msg())
|
||||
|
||||
done, _ = await asyncio.wait(
|
||||
{pending_task}, timeout=_HEARTBEAT_INTERVAL
|
||||
)
|
||||
|
||||
if not done:
|
||||
await lock.refresh()
|
||||
for ev in compaction.emit_start_if_ready():
|
||||
yield ev
|
||||
yield StreamHeartbeat()
|
||||
continue
|
||||
|
||||
# Task completed — get result
|
||||
pending_task = None
|
||||
try:
|
||||
sdk_msg = done.pop().result()
|
||||
except StopAsyncIteration:
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: %s %s "
|
||||
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
|
||||
session_id[:12],
|
||||
)
|
||||
break
|
||||
except Exception as stream_err:
|
||||
# SDK sends {"type": "error"} which raises
|
||||
# Exception in receive_response() — capture it
|
||||
# so the session can still be saved and the
|
||||
# frontend gets a clean finish.
|
||||
logger.error(
|
||||
"[SDK] [%s] Stream error from SDK: %s",
|
||||
session_id[:12],
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
yield StreamError(
|
||||
errorText=f"SDK stream error: {stream_err}",
|
||||
code="sdk_stream_error",
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: %s %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
session_id[:12],
|
||||
type(sdk_msg).__name__,
|
||||
getattr(sdk_msg, "subtype", ""),
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
len(adapter.current_tool_calls),
|
||||
len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
# wait_for_stash() awaits an asyncio.Event signaled by
|
||||
# stash_pending_tool_output(), completing as soon as
|
||||
# the hook finishes (typically <1ms). The sleep(0)
|
||||
# after lets any remaining concurrent hooks complete.
|
||||
#
|
||||
# Skip for parallel tool continuations: when the SDK
|
||||
# sends parallel tool calls as separate
|
||||
# AssistantMessages (each containing only
|
||||
# ToolUseBlocks), we must NOT wait/flush — the prior
|
||||
# tools are still executing concurrently.
|
||||
is_parallel_continuation = isinstance(
|
||||
sdk_msg, AssistantMessage
|
||||
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
|
||||
|
||||
if (
|
||||
adapter.has_unresolved_tool_calls
|
||||
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
|
||||
and not is_parallel_continuation
|
||||
):
|
||||
if await wait_for_stash(timeout=0.5):
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Timed out waiting for "
|
||||
"PostToolUse hook stash "
|
||||
"(%d unresolved tool calls)",
|
||||
session_id[:12],
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log ResultMessage details for debugging
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: ResultMessage %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
session_id[:12],
|
||||
type(sdk_msg).__name__,
|
||||
getattr(sdk_msg, "subtype", ""),
|
||||
sdk_msg.subtype,
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
len(adapter.current_tool_calls),
|
||||
len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
# wait_for_stash() awaits an asyncio.Event signaled by
|
||||
# stash_pending_tool_output(), completing as soon as
|
||||
# the hook finishes (typically <1ms). The sleep(0)
|
||||
# after lets any remaining concurrent hooks complete.
|
||||
#
|
||||
# Skip for parallel tool continuations: when the SDK
|
||||
# sends parallel tool calls as separate
|
||||
# AssistantMessages (each containing only
|
||||
# ToolUseBlocks), we must NOT wait/flush — the prior
|
||||
# tools are still executing concurrently.
|
||||
is_parallel_continuation = isinstance(
|
||||
sdk_msg, AssistantMessage
|
||||
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
|
||||
|
||||
if (
|
||||
adapter.has_unresolved_tool_calls
|
||||
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
|
||||
and not is_parallel_continuation
|
||||
):
|
||||
if await wait_for_stash(timeout=0.5):
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Timed out waiting for "
|
||||
"PostToolUse hook stash "
|
||||
"(%d unresolved tool calls)",
|
||||
session_id[:12],
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log ResultMessage details for debugging
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: ResultMessage %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
if sdk_msg.subtype in ("error", "error_during_execution"):
|
||||
logger.error(
|
||||
"[SDK] [%s] SDK execution failed with error: %s",
|
||||
session_id[:12],
|
||||
sdk_msg.subtype,
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
len(adapter.current_tool_calls),
|
||||
len(adapter.resolved_tool_calls),
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
if sdk_msg.subtype in ("error", "error_during_execution"):
|
||||
logger.error(
|
||||
"[SDK] [%s] SDK execution failed with error: %s",
|
||||
session_id[:12],
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
|
||||
# Emit compaction end if SDK finished compacting
|
||||
for ev in await compaction.emit_end_if_ready(session):
|
||||
yield ev
|
||||
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
|
||||
# Log tool events for debugging
|
||||
if isinstance(
|
||||
response,
|
||||
(
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
),
|
||||
):
|
||||
extra = ""
|
||||
if isinstance(response, StreamToolOutputAvailable):
|
||||
out_len = len(str(response.output))
|
||||
extra = f", output_len={out_len}"
|
||||
logger.info(
|
||||
"[SDK] [%s] Tool event: %s, tool=%s%s",
|
||||
session_id[:12],
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
extra,
|
||||
)
|
||||
|
||||
# Log errors being sent to frontend
|
||||
if isinstance(response, StreamError):
|
||||
logger.error(
|
||||
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
|
||||
session_id[:12],
|
||||
response.errorText,
|
||||
response.code,
|
||||
)
|
||||
|
||||
yield response
|
||||
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, start a new assistant
|
||||
# message for the post-tool text.
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
|
||||
# Emit compaction end if SDK finished compacting
|
||||
for ev in await compaction.emit_end_if_ready(session):
|
||||
yield ev
|
||||
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
|
||||
# Log tool events for debugging
|
||||
if isinstance(
|
||||
response,
|
||||
(
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
),
|
||||
):
|
||||
extra = ""
|
||||
if isinstance(response, StreamToolOutputAvailable):
|
||||
out_len = len(str(response.output))
|
||||
extra = f", output_len={out_len}"
|
||||
logger.info(
|
||||
"[SDK] [%s] Tool event: %s, tool=%s%s",
|
||||
session_id[:12],
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
extra,
|
||||
)
|
||||
|
||||
# Log errors being sent to frontend
|
||||
if isinstance(response, StreamError):
|
||||
logger.error(
|
||||
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
|
||||
session_id[:12],
|
||||
response.errorText,
|
||||
response.code,
|
||||
)
|
||||
|
||||
yield response
|
||||
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, start a new assistant
|
||||
# message for the post-tool text.
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
accumulated_tool_calls = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(
|
||||
response.input or {}
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
accumulated_tool_calls = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task/generator was cancelled (e.g. client disconnect,
|
||||
# server shutdown). Log and let the safety-net / finally
|
||||
# blocks handle cleanup.
|
||||
logger.warning(
|
||||
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
|
||||
session_id[:12],
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Cancel the pending __anext__ task to avoid a leaked
|
||||
# coroutine. This is safe even if the task already
|
||||
# completed.
|
||||
if pending_task is not None and not pending_task.done():
|
||||
pending_task.cancel()
|
||||
try:
|
||||
await pending_task
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
pass
|
||||
|
||||
# Safety net: if tools are still unresolved after the
|
||||
# streaming loop (e.g. StopAsyncIteration before ResultMessage,
|
||||
# or SDK not sending UserMessages for built-in tools), flush
|
||||
# them now so the frontend stops showing spinners.
|
||||
if adapter.has_unresolved_tool_calls:
|
||||
logger.warning(
|
||||
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
|
||||
"flushing as safety net",
|
||||
session_id[:12],
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
safety_responses: list[StreamBaseResponse] = []
|
||||
adapter._flush_unresolved_tool_calls(safety_responses)
|
||||
for response in safety_responses:
|
||||
if isinstance(
|
||||
response,
|
||||
(StreamToolInputAvailable, StreamToolOutputAvailable),
|
||||
):
|
||||
logger.info(
|
||||
"[SDK] [%s] Safety flush: %s, tool=%s",
|
||||
session_id[:12],
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(response.input or {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
yield response
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
# If the stream ended without a ResultMessage, the SDK
|
||||
# CLI exited unexpectedly or the user stopped execution.
|
||||
# Close any open text/step so chunks are well-formed, and
|
||||
# append a cancellation message so users see feedback.
|
||||
# StreamFinish is published by mark_session_completed in the processor.
|
||||
if not stream_completed:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
|
||||
session_id[:12],
|
||||
)
|
||||
closing_responses: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(closing_responses)
|
||||
for r in closing_responses:
|
||||
yield r
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
# Add "Stopped by user" message so it persists after refresh
|
||||
# Use COPILOT_SYSTEM_PREFIX so frontend renders it as system message, not assistant
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user",
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task/generator was cancelled (e.g. client disconnect,
|
||||
# server shutdown). Log and let the safety-net / finally
|
||||
# blocks handle cleanup.
|
||||
logger.warning(
|
||||
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
|
||||
session_id[:12],
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Cancel the pending __anext__ task to avoid a leaked
|
||||
# coroutine. This is safe even if the task already
|
||||
# completed.
|
||||
if pending_task is not None and not pending_task.done():
|
||||
pending_task.cancel()
|
||||
try:
|
||||
await pending_task
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
# Expected: task was cancelled or exhausted during cleanup
|
||||
logger.info(
|
||||
"[SDK] Pending __anext__ task completed during cleanup"
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# After async with the SDK task group has exited, so the Stop
|
||||
# hook has already fired and the CLI has been SIGTERMed. The
|
||||
# CLI uses appendFileSync, so all writes are safely on disk.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
# With --resume the CLI appends to the resume file (most
|
||||
# complete). Otherwise use the Stop hook path.
|
||||
if use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
logger.debug("[SDK] Transcript source: resume file")
|
||||
elif captured_transcript.path:
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
logger.debug(
|
||||
"[SDK] Transcript source: stop hook (%s), read result: %s",
|
||||
captured_transcript.path,
|
||||
f"{len(raw_transcript)}B" if raw_transcript else "None",
|
||||
)
|
||||
else:
|
||||
raw_transcript = None
|
||||
|
||||
if not raw_transcript:
|
||||
logger.debug(
|
||||
"[SDK] No usable transcript — CLI file had no "
|
||||
"conversation entries (expected for first turn "
|
||||
"without --resume)"
|
||||
)
|
||||
|
||||
if raw_transcript:
|
||||
# Shield the upload from generator cancellation so a
|
||||
# client disconnect / page refresh doesn't lose the
|
||||
# transcript. The upload must finish even if the SSE
|
||||
# connection is torn down.
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
# Safety net: if tools are still unresolved after the
|
||||
# streaming loop (e.g. StopAsyncIteration before ResultMessage,
|
||||
# or SDK not sending UserMessages for built-in tools), flush
|
||||
# them now so the frontend stops showing spinners.
|
||||
if adapter.has_unresolved_tool_calls:
|
||||
logger.warning(
|
||||
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
|
||||
"flushing as safety net",
|
||||
session_id[:12],
|
||||
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
|
||||
)
|
||||
safety_responses: list[StreamBaseResponse] = []
|
||||
adapter._flush_unresolved_tool_calls(safety_responses)
|
||||
for response in safety_responses:
|
||||
if isinstance(
|
||||
response,
|
||||
(StreamToolInputAvailable, StreamToolOutputAvailable),
|
||||
):
|
||||
logger.info(
|
||||
"[SDK] [%s] Safety flush: %s, tool=%s",
|
||||
session_id[:12],
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
)
|
||||
)
|
||||
yield response
|
||||
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
"claude-agent-sdk is not installed. "
|
||||
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
|
||||
"to use the OpenAI-compatible fallback."
|
||||
)
|
||||
# If the stream ended without a ResultMessage, the SDK
|
||||
# CLI exited unexpectedly or the user stopped execution.
|
||||
# Close any open text/step so chunks are well-formed, and
|
||||
# append a cancellation message so users see feedback.
|
||||
# StreamFinish is published by mark_session_completed in the processor.
|
||||
if not stream_completed and not ended_with_stream_error:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
|
||||
session_id[:12],
|
||||
)
|
||||
closing_responses: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(closing_responses)
|
||||
for r in closing_responses:
|
||||
yield r
|
||||
|
||||
# Add "Stopped by user" message so it persists after refresh
|
||||
# Use COPILOT_SYSTEM_PREFIX so frontend renders it as system message, not assistant
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user",
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# After async with the SDK task group has exited, so the Stop
|
||||
# hook has already fired and the CLI has been SIGTERMed. The
|
||||
# CLI uses appendFileSync, so all writes are safely on disk.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
# With --resume the CLI appends to the resume file (most
|
||||
# complete). Otherwise use the Stop hook path.
|
||||
if use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
logger.debug("[SDK] Transcript source: resume file")
|
||||
elif captured_transcript.path:
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
logger.debug(
|
||||
"[SDK] Transcript source: stop hook (%s), read result: %s",
|
||||
captured_transcript.path,
|
||||
f"{len(raw_transcript)}B" if raw_transcript else "None",
|
||||
)
|
||||
else:
|
||||
raw_transcript = None
|
||||
|
||||
if not raw_transcript:
|
||||
logger.debug(
|
||||
"[SDK] No usable transcript — CLI file had no "
|
||||
"conversation entries (expected for first turn "
|
||||
"without --resume)"
|
||||
)
|
||||
|
||||
if raw_transcript:
|
||||
# Shield the upload from generator cancellation so a
|
||||
# client disconnect / page refresh doesn't lose the
|
||||
# transcript. The upload must finish even if the SSE
|
||||
# connection is torn down.
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream completed successfully with %d messages",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,75 +4,14 @@ from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
|
||||
from .response_model import StreamError, StreamTextDelta
|
||||
from .sdk import service as sdk_service
|
||||
from .sdk.transcript import download_transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
|
||||
# StreamFinish is published by mark_session_completed (processor layer),
|
||||
# not by the service. The generator completing means the stream ended.
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
||||
user_id=session.user_id,
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the SDK --resume path captures and uses transcripts across turns.
|
||||
|
||||
Reference in New Issue
Block a user