feat(copilot): remove legacy copilot, add baseline non-SDK mode with tool calling (#12276)

## Summary
- Remove ~1200 lines of broken/unmaintained non-SDK copilot streaming
code (retry logic, parallel tool calls, context window management)
- Add `stream_chat_completion_baseline()` as a clean fallback LLM path
with full tool-calling support when `CHAT_USE_CLAUDE_AGENT_SDK=false`
(e.g. when Anthropic is down)
- Baseline reuses the same shared `TOOL_REGISTRY`,
`get_available_tools()`, and `execute_tool()` as the SDK path
- Move baseline code to dedicated `baseline/` folder (mirrors `sdk/`
structure)
- Clean up SDK service: remove unused params, fix model/env resolution,
fix stream error persistence
- Clean up config: remove `max_retries`, `thinking_enabled` fields
(non-SDK only)
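
As a quick orientation for reviewers, here is a minimal sketch of how the two paths are now selected (simplified from the `executor/processor.py` diff below; `run_turn` and its arguments are illustrative, not the actual processor API):

```python
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.sdk import service as sdk_service

config = ChatConfig()  # use_claude_agent_sdk is read from CHAT_USE_CLAUDE_AGENT_SDK


async def run_turn(session_id: str, message: str, user_id: str) -> None:
    # Hypothetical wrapper: both services expose the same async-generator
    # signature, so callers only choose which function to invoke.
    stream_fn = (
        sdk_service.stream_chat_completion_sdk
        if config.use_claude_agent_sdk
        else stream_chat_completion_baseline
    )
    async for chunk in stream_fn(session_id, message=message, user_id=user_id):
        ...  # publish each chunk to the stream registry
```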

## Changes
| File | Action |
|------|--------|
| `backend/copilot/baseline/__init__.py` | New — package export |
| `backend/copilot/baseline/service.py` | New — baseline streaming with tool-call loop |
| `backend/copilot/baseline/service_test.py` | New — multi-turn keyword recall test |
| `backend/copilot/service.py` | Remove ~1200 lines of legacy code, keep shared helpers only |
| `backend/copilot/executor/processor.py` | Simplify branching to SDK vs baseline |
| `backend/copilot/sdk/service.py` | Remove unused params, fix model/env separation, fix stream error persistence |
| `backend/copilot/config.py` | Remove `max_retries`, `thinking_enabled` |
| `backend/copilot/service_test.py` | Keep SDK test only (baseline test moved) |
| `backend/copilot/parallel_tool_calls_test.py` | Deleted (tested removed code) |

## Test plan
- [x] `poetry run format` passes
- [x] CI passes (all 3 Python versions, types, CodeQL)
- [ ] SDK path works unchanged in production
- [x] Baseline path (`CHAT_USE_CLAUDE_AGENT_SDK=false`) streams
responses with tool calling
- [x] Baseline emits correct Vercel AI SDK stream protocol events
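
For the last checkbox, the per-turn event sequence the baseline emits looks like this (reconstructed from `baseline/service.py` below; the list is illustrative, showing one tool round followed by a text round):

```python
# Event type names from backend.copilot.response_model, in emission order:
expected_shape = [
    "StreamStart",                # once per assistant message
    "StreamStartStep",            # round 1 opens
    "StreamFinishStep",           # round 1 closes before tools execute
    "StreamToolInputStart",       # per tool call
    "StreamToolInputAvailable",
    "StreamToolOutputAvailable",  # result of execute_tool()
    "StreamStartStep",            # round 2: model responds with text
    "StreamTextStart",
    "StreamTextDelta",            # repeated once per streamed chunk
    "StreamTextEnd",
    "StreamFinishStep",
    "StreamFinish",               # emitted from the generator's finally block
]
```
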
Author: Zamil Majdy
Committed: 2026-03-04 20:51:46 +07:00 (via GitHub)
Parent: 160d6eddfb · Commit: 0215332386
9 changed files with 971 additions and 2049 deletions

### `backend/copilot/baseline/__init__.py` (new)

@@ -0,0 +1,3 @@
from .service import stream_chat_completion_baseline
__all__ = ["stream_chat_completion_baseline"]

### `backend/copilot/baseline/service.py` (new)

@@ -0,0 +1,398 @@
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
Claude Agent SDK / Anthropic API is unavailable. Routes through any
OpenAI-compatible provider (OpenRouter by default) and reuses the same
shared tool registry as the SDK path.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
import orjson
from backend.copilot.model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
client,
config,
)
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
# Maximum number of tool-call rounds before forcing a text response.
_MAX_TOOL_ROUNDS = 30
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title:
await update_session_title(session_id, title)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
async def _compress_session_messages(
messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Compress session messages if they exceed the model's token limit.
Uses the shared compress_context() utility which supports LLM-based
summarization of older messages while keeping recent ones intact,
with progressive truncation and middle-out deletion as fallbacks.
"""
messages_dict = []
for msg in messages:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
messages_dict.append(msg_dict)
try:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
"[Baseline] Context compacted: %d -> %d tokens "
"(%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
return [
ChatMessage(role=m["role"], content=m.get("content"))
for m in result.messages
]
return messages
async def stream_chat_completion_baseline(
session_id: str,
message: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
session: ChatSession | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Baseline LLM with tool calling via OpenAI-compatible API.
Designed as a fallback when the Claude Agent SDK is unavailable.
Uses the same tool registry as the SDK path but routes through any
OpenAI-compatible provider (e.g. OpenRouter).
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
# Append user message
new_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_role, content=message))
if is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
session = await upsert_chat_session(session)
# Generate title for new sessions
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
message_id = str(uuid.uuid4())
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
# Build OpenAI message list from session history
openai_messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}
]
for msg in messages_for_context:
if msg.role in ("user", "assistant") and msg.content:
openai_messages.append({"role": msg.role, "content": msg.content})
tools = get_available_tools()
yield StreamStart(messageId=message_id, sessionId=session_id)
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
)
if tools:
create_kwargs["tools"] = tools
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"],
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
)
logger.error("[Baseline] %s", limit_msg)
yield StreamError(
errorText=limit_msg,
code="baseline_tool_round_limit",
)
except Exception as e:
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Persist assistant response
if assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=assistant_text)
)
try:
await upsert_chat_session(session)
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
yield StreamFinish()
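
A minimal consumption sketch for the generator above (the session and user IDs are hypothetical placeholders; in practice the session comes from `create_chat_session` first):

```python
import asyncio

from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.response_model import StreamTextDelta, StreamToolOutputAvailable


async def main() -> None:
    # "sess-123" / "user-1" are illustrative IDs, not real fixtures.
    async for chunk in stream_chat_completion_baseline(
        "sess-123", "Find me an agent for invoicing", user_id="user-1"
    ):
        if isinstance(chunk, StreamTextDelta):
            print(chunk.delta, end="", flush=True)
        elif isinstance(chunk, StreamToolOutputAvailable):
            print(f"\n[tool {chunk.toolName} finished]")


asyncio.run(main())
```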

### `backend/copilot/baseline/service_test.py` (new)

@@ -0,0 +1,99 @@
import logging
from os import getenv

import pytest

from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.model import (
    create_chat_session,
    get_chat_session,
    upsert_chat_session,
)
from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamStart,
    StreamTextDelta,
)

logger = logging.getLogger(__name__)


@pytest.mark.asyncio(loop_scope="session")
async def test_baseline_multi_turn(setup_test_user, test_user_id):
    """Test that the baseline LLM path streams responses and maintains history.

    Turn 1: Send a message with a unique keyword.
    Turn 2: Ask the model to recall the keyword — proving conversation history
    is correctly passed to the single-call LLM.
    """
    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
    if not api_key:
        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")

    session = await create_chat_session(test_user_id)
    session = await upsert_chat_session(session)

    # --- Turn 1: send a message with a unique keyword ---
    keyword = "QUASAR99"
    turn1_msg = (
        f"Please remember this special keyword: {keyword}. "
        "Just confirm you've noted it, keep your response brief."
    )
    turn1_text = ""
    turn1_errors: list[str] = []
    got_start = False
    got_finish = False
    async for chunk in stream_chat_completion_baseline(
        session.session_id,
        turn1_msg,
        user_id=test_user_id,
    ):
        if isinstance(chunk, StreamStart):
            got_start = True
        elif isinstance(chunk, StreamTextDelta):
            turn1_text += chunk.delta
        elif isinstance(chunk, StreamError):
            turn1_errors.append(chunk.errorText)
        elif isinstance(chunk, StreamFinish):
            got_finish = True

    assert got_start, "Turn 1 did not yield StreamStart"
    assert got_finish, "Turn 1 did not yield StreamFinish"
    assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
    assert turn1_text, "Turn 1 produced no text"
    logger.info(f"Turn 1 response: {turn1_text[:100]}")

    # Reload session for turn 2
    session = await get_chat_session(session.session_id, test_user_id)
    assert session, "Session not found after turn 1"
    # Verify messages were persisted (user + assistant)
    assert (
        len(session.messages) >= 2
    ), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"

    # --- Turn 2: ask model to recall the keyword ---
    turn2_msg = "What was the special keyword I asked you to remember?"
    turn2_text = ""
    turn2_errors: list[str] = []
    async for chunk in stream_chat_completion_baseline(
        session.session_id,
        turn2_msg,
        user_id=test_user_id,
        session=session,
    ):
        if isinstance(chunk, StreamTextDelta):
            turn2_text += chunk.delta
        elif isinstance(chunk, StreamError):
            turn2_errors.append(chunk.errorText)

    assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
    assert turn2_text, "Turn 2 produced no text"
    assert keyword in turn2_text, (
        f"Model did not recall keyword '{keyword}' in turn 2. "
        f"Response: {turn2_text[:200]}"
    )
    logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")

### `backend/copilot/config.py`

@@ -26,11 +26,6 @@ class ChatConfig(BaseSettings):
     # Session TTL Configuration - 12 hours
     session_ttl: int = Field(default=43200, description="Session TTL in seconds")
     # Streaming Configuration
-    max_retries: int = Field(
-        default=3,
-        description="Max retries for fallback path (SDK handles retries internally)",
-    )
     max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
     max_agent_schedules: int = Field(
         default=30, description="Maximum number of agent schedules"
@@ -71,7 +66,7 @@ class ChatConfig(BaseSettings):
     # Claude Agent SDK Configuration
     use_claude_agent_sdk: bool = Field(
         default=True,
-        description="Use Claude Agent SDK for chat completions",
+        description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
     )
     claude_agent_model: str | None = Field(
         default=None,
@@ -113,12 +108,6 @@ class ChatConfig(BaseSettings):
         description="E2B sandbox keepalive timeout in seconds.",
     )
-    # Extended thinking configuration for Claude models
-    thinking_enabled: bool = Field(
-        default=True,
-        description="Enable adaptive thinking for Claude models via OpenRouter",
-    )
     @field_validator("use_e2b_sandbox", mode="before")
     @classmethod
     def get_use_e2b_sandbox(cls, v):
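
A small sketch of flipping the new toggle in a test or script (this assumes the field maps to the `CHAT_USE_CLAUDE_AGENT_SDK` env var, as the PR description indicates, and that pydantic-settings reads it at construction time):

```python
import os

# Force the OpenAI-compatible baseline path, e.g. during an Anthropic outage.
os.environ["CHAT_USE_CLAUDE_AGENT_SDK"] = "false"

from backend.copilot.config import ChatConfig

config = ChatConfig()
assert config.use_claude_agent_sdk is False  # baseline path will be used
```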

### `backend/copilot/executor/processor.py`

@@ -9,8 +9,8 @@ import logging
 import threading
 import time
 from backend.copilot import service as copilot_service
 from backend.copilot import stream_registry
+from backend.copilot.baseline import stream_chat_completion_baseline
 from backend.copilot.config import ChatConfig
 from backend.copilot.response_model import StreamFinish
 from backend.copilot.sdk import service as sdk_service
@@ -194,7 +194,7 @@ class CoPilotProcessor:
     ):
         """Async execution logic for a CoPilot turn.

-        Calls the stream_chat_completion service function and publishes
+        Calls the chat completion service (SDK or baseline) and publishes
         results to the stream registry.

         Args:
@@ -218,9 +218,9 @@
         stream_fn = (
             sdk_service.stream_chat_completion_sdk
             if use_sdk
-            else copilot_service.stream_chat_completion
+            else stream_chat_completion_baseline
         )
-        log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
+        log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")

         # Stream chat completion and publish chunks to Redis.
         async for chunk in stream_fn(
@@ -228,7 +228,6 @@
             message=entry.message if entry.message else None,
             is_user_message=entry.is_user_message,
             user_id=entry.user_id,
-            context=entry.context,
         ):
             if cancel.is_set():
                 log.info("Cancel requested, breaking stream")

### `backend/copilot/parallel_tool_calls_test.py` (deleted)

@@ -1,269 +0,0 @@
"""Tests for parallel tool call execution in CoPilot.
These tests mock _yield_tool_call to avoid importing the full copilot stack
which requires Prisma, DB connections, etc.
"""
import asyncio
import time
from typing import Any, cast
import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
n_tools = 3
delay_per_tool = 0.2
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"tool_{i}", "arguments": "{}"},
}
for i in range(n_tools)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
original_yield = None
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
input={},
)
await asyncio.sleep(delay_per_tool)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
output="{}",
)
import backend.copilot.service as svc
original_yield = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
start = time.monotonic()
events = []
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
elapsed = time.monotonic() - start
finally:
svc._yield_tool_call = original_yield
assert len(events) == n_tools * 2
# Parallel: should take ~delay, not ~n*delay
assert elapsed < delay_per_tool * (
n_tools - 0.5
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
@pytest.mark.asyncio
async def test_single_tool_call_works():
"""Single tool call should work identically."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": "call_0",
"type": "function",
"function": {"name": "t", "arguments": "{}"},
}
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = [
e
async for e in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
)
]
finally:
svc._yield_tool_call = orig
assert len(events) == 2
@pytest.mark.asyncio
async def test_retryable_error_propagates():
"""Retryable errors should be raised after all tools finish."""
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
)
await asyncio.sleep(0.05)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = []
with pytest.raises(KeyError):
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
# First tool's events should still be yielded
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
finally:
svc._yield_tool_call = orig
@pytest.mark.asyncio
async def test_session_shared_across_parallel_tools():
"""All parallel tools should receive the same session instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(3)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
observed_sessions = []
async def fake_yield(tc_list, idx, sess):
observed_sessions.append(sess)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
async for _ in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
pass
finally:
svc._yield_tool_call = orig
assert len(observed_sessions) == 3
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
@pytest.mark.asyncio
async def test_cancellation_cleans_up():
"""Generator close should cancel in-flight tasks."""
from backend.copilot.response_model import StreamToolInputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
started.set()
await asyncio.sleep(10) # simulate long-running
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
await gen.__anext__() # get first event
await started.wait()
await gen.aclose() # close generator
finally:
svc._yield_tool_call = orig
# If we get here without hanging, cleanup worked

### `backend/copilot/sdk/service.py`

@@ -571,17 +571,12 @@ async def _build_query_message(
 async def stream_chat_completion_sdk(
     session_id: str,
     message: str | None = None,
-    tool_call_response: str | None = None,  # noqa: ARG001
     is_user_message: bool = True,
     user_id: str | None = None,
-    retry_count: int = 0,  # noqa: ARG001
     session: ChatSession | None = None,
-    context: dict[str, str] | None = None,  # noqa: ARG001
     **_kwargs: Any,
 ) -> AsyncGenerator[StreamBaseResponse, None]:
-    """Stream chat completion using Claude Agent SDK.
-
-    Drop-in replacement for stream_chat_completion with improved reliability.
-    """
+    """Stream chat completion using Claude Agent SDK."""
     if session is None:
         session = await get_chat_session(session_id, user_id)
@@ -728,495 +723,491 @@ async def stream_chat_completion_sdk(
yield StreamStart(messageId=message_id, sessionId=session_id)
set_execution_context(user_id, session, sandbox=e2b_sandbox, sdk_cwd=sdk_cwd)
try:
# Fail fast when no API credentials are available at all
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY "
"(or CHAT_API_KEY) for OpenRouter routing, "
"or ANTHROPIC_API_KEY for direct Anthropic access."
)
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
sdk_model = _resolve_sdk_model()
# --- Transcript capture via Stop hook ---
# Read the file content immediately — the SDK may clean up
# the file before our finally block runs.
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
content = read_transcript_file(transcript_path)
if content:
captured_transcript.raw_content = content
logger.info(
f"[SDK] Stop hook: captured {len(content)}B from "
f"{transcript_path}"
)
else:
logger.warning(
f"[SDK] Stop hook: transcript file empty/missing at "
f"{transcript_path}"
)
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
compaction = CompactionTracker()
security_hooks = create_security_hooks(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
on_compact=compaction.on_compact,
# Fail fast when no API credentials are available at all
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY "
"(or CHAT_API_KEY) for OpenRouter routing, "
"or ANTHROPIC_API_KEY for direct Anthropic access."
)
# --- Resume strategy: download transcript from bucket ---
transcript_msg_count = 0 # watermark: session.messages length at upload
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
dl = await download_transcript(user_id, session_id)
is_valid = bool(dl and validate_transcript(dl.content))
if dl and is_valid:
logger.info(
f"[SDK] Transcript available for session {session_id}: "
f"{len(dl.content)}B, msg_count={dl.message_count}"
)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"[SDK] Using --resume ({len(dl.content)}B, "
f"msg_count={transcript_msg_count})"
)
elif dl:
logger.warning(
f"[SDK] Transcript downloaded but invalid for {session_id}"
)
else:
logger.warning(
f"[SDK] No transcript available for {session_id} "
f"({len(session.messages)} messages in session)"
)
sdk_model = _resolve_sdk_model()
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": allowed,
"disallowed_tools": disallowed,
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
}
if sdk_env:
sdk_options_kwargs["model"] = sdk_model
sdk_options_kwargs["env"] = sdk_env
if use_resume and resume_file:
sdk_options_kwargs["resume"] = resume_file
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
# Propagate user_id/session_id as OTEL context attributes so the
# langsmith tracing integration attaches them to every span. This
# is what Langfuse (or any OTEL backend) maps to its native
# user/session fields.
_otel_ctx = propagate_attributes(
user_id=user_id,
session_id=session_id,
trace_name="copilot-sdk",
tags=["sdk"],
metadata={"resume": str(use_resume)},
)
_otel_ctx.__enter__()
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
return
query_message, was_compacted = await _build_query_message(
current_message,
session,
use_resume,
transcript_msg_count,
session_id,
)
# --- Transcript capture via Stop hook ---
# Read the file content immediately — the SDK may clean up
# the file before our finally block runs.
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
content = read_transcript_file(transcript_path)
if content:
captured_transcript.raw_content = content
logger.info(
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
session_id[:12],
use_resume,
len(session.messages),
len(query_message),
f"[SDK] Stop hook: captured {len(content)}B from "
f"{transcript_path}"
)
else:
logger.warning(
f"[SDK] Stop hook: transcript file empty/missing at "
f"{transcript_path}"
)
compaction.reset_for_query()
if was_compacted:
for ev in compaction.emit_pre_query(session):
yield ev
await client.query(query_message, session_id=session_id)
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
compaction = CompactionTracker()
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
security_hooks = create_security_hooks(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
on_compact=compaction.on_compact,
)
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
# anyio memory stream, causing StopAsyncIteration on the next
# call and silently dropping all in-flight tool results.
# Instead, wrap __anext__() in a Task and use asyncio.wait()
# with a timeout. On timeout we emit a heartbeat but keep the
# Task alive so it can deliver the next message.
msg_iter = client.receive_response().__aiter__()
pending_task: asyncio.Task[Any] | None = None
try:
while not stream_completed:
if pending_task is None:
# --- Resume strategy: download transcript from bucket ---
transcript_msg_count = 0 # watermark: session.messages length at upload
async def _next_msg() -> Any:
return await msg_iter.__anext__()
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
dl = await download_transcript(user_id, session_id)
is_valid = bool(dl and validate_transcript(dl.content))
if dl and is_valid:
logger.info(
f"[SDK] Transcript available for session {session_id}: "
f"{len(dl.content)}B, msg_count={dl.message_count}"
)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"[SDK] Using --resume ({len(dl.content)}B, "
f"msg_count={transcript_msg_count})"
)
elif dl:
logger.warning(
f"[SDK] Transcript downloaded but invalid for {session_id}"
)
else:
logger.warning(
f"[SDK] No transcript available for {session_id} "
f"({len(session.messages)} messages in session)"
)
pending_task = asyncio.create_task(_next_msg())
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": allowed,
"disallowed_tools": disallowed,
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
}
if sdk_model:
sdk_options_kwargs["model"] = sdk_model
if sdk_env:
sdk_options_kwargs["env"] = sdk_env
if use_resume and resume_file:
sdk_options_kwargs["resume"] = resume_file
done, _ = await asyncio.wait(
{pending_task}, timeout=_HEARTBEAT_INTERVAL
)
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] # dynamic kwargs
if not done:
await lock.refresh()
for ev in compaction.emit_start_if_ready():
yield ev
yield StreamHeartbeat()
continue
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
# Task completed — get result
pending_task = None
try:
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
session_id[:12],
)
break
except Exception as stream_err:
# SDK sends {"type": "error"} which raises
# Exception in receive_response() — capture it
# so the session can still be saved and the
# frontend gets a clean finish.
logger.error(
"[SDK] [%s] Stream error from SDK: %s",
session_id[:12],
stream_err,
exc_info=True,
)
yield StreamError(
errorText=f"SDK stream error: {stream_err}",
code="sdk_stream_error",
)
break
# Propagate user_id/session_id as OTEL context attributes so the
# langsmith tracing integration attaches them to every span. This
# is what Langfuse (or any OTEL backend) maps to its native
# user/session fields.
_otel_ctx = propagate_attributes(
user_id=user_id,
session_id=session_id,
trace_name="copilot-sdk",
tags=["sdk"],
metadata={"resume": str(use_resume)},
)
_otel_ctx.__enter__()
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
return
query_message, was_compacted = await _build_query_message(
current_message,
session,
use_resume,
transcript_msg_count,
session_id,
)
logger.info(
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
session_id[:12],
use_resume,
len(session.messages),
len(query_message),
)
compaction.reset_for_query()
if was_compacted:
for ev in compaction.emit_pre_query(session):
yield ev
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
ended_with_stream_error = False
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
# anyio memory stream, causing StopAsyncIteration on the next
# call and silently dropping all in-flight tool results.
# Instead, wrap __anext__() in a Task and use asyncio.wait()
# with a timeout. On timeout we emit a heartbeat but keep the
# Task alive so it can deliver the next message.
msg_iter = client.receive_response().__aiter__()
pending_task: asyncio.Task[Any] | None = None
try:
while not stream_completed:
if pending_task is None:
async def _next_msg() -> Any:
return await msg_iter.__anext__()
pending_task = asyncio.create_task(_next_msg())
done, _ = await asyncio.wait(
{pending_task}, timeout=_HEARTBEAT_INTERVAL
)
if not done:
await lock.refresh()
for ev in compaction.emit_start_if_ready():
yield ev
yield StreamHeartbeat()
continue
# Task completed — get result
pending_task = None
try:
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Received: %s %s "
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
session_id[:12],
)
break
except Exception as stream_err:
# SDK sends {"type": "error"} which raises
# Exception in receive_response() — capture it
# so the session can still be saved and the
# frontend gets a clean finish.
logger.error(
"[SDK] [%s] Stream error from SDK: %s",
session_id[:12],
stream_err,
exc_info=True,
)
ended_with_stream_error = True
yield StreamError(
errorText=f"SDK stream error: {stream_err}",
code="sdk_stream_error",
)
break
logger.info(
"[SDK] [%s] Received: %s %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
# wait_for_stash() awaits an asyncio.Event signaled by
# stash_pending_tool_output(), completing as soon as
# the hook finishes (typically <1ms). The sleep(0)
# after lets any remaining concurrent hooks complete.
#
# Skip for parallel tool continuations: when the SDK
# sends parallel tool calls as separate
# AssistantMessages (each containing only
# ToolUseBlocks), we must NOT wait/flush — the prior
# tools are still executing concurrently.
is_parallel_continuation = isinstance(
sdk_msg, AssistantMessage
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
if (
adapter.has_unresolved_tool_calls
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
and not is_parallel_continuation
):
if await wait_for_stash(timeout=0.5):
await asyncio.sleep(0)
else:
logger.warning(
"[SDK] [%s] Timed out waiting for "
"PostToolUse hook stash "
"(%d unresolved tool calls)",
session_id[:12],
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"[SDK] [%s] Received: ResultMessage %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
sdk_msg.subtype,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
# wait_for_stash() awaits an asyncio.Event signaled by
# stash_pending_tool_output(), completing as soon as
# the hook finishes (typically <1ms). The sleep(0)
# after lets any remaining concurrent hooks complete.
#
# Skip for parallel tool continuations: when the SDK
# sends parallel tool calls as separate
# AssistantMessages (each containing only
# ToolUseBlocks), we must NOT wait/flush — the prior
# tools are still executing concurrently.
is_parallel_continuation = isinstance(
sdk_msg, AssistantMessage
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
if (
adapter.has_unresolved_tool_calls
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
and not is_parallel_continuation
):
if await wait_for_stash(timeout=0.5):
await asyncio.sleep(0)
else:
logger.warning(
"[SDK] [%s] Timed out waiting for "
"PostToolUse hook stash "
"(%d unresolved tool calls)",
session_id[:12],
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"[SDK] [%s] Received: ResultMessage %s "
"(unresolved=%d, current=%d, resolved=%d)",
if sdk_msg.subtype in ("error", "error_during_execution"):
logger.error(
"[SDK] [%s] SDK execution failed with error: %s",
session_id[:12],
sdk_msg.subtype,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
sdk_msg.result or "(no error message provided)",
)
if sdk_msg.subtype in ("error", "error_during_execution"):
logger.error(
"[SDK] [%s] SDK execution failed with error: %s",
session_id[:12],
sdk_msg.result or "(no error message provided)",
# Emit compaction end if SDK finished compacting
for ev in await compaction.emit_end_if_ready(session):
yield ev
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
# Log tool events for debugging
if isinstance(
response,
(
StreamToolInputAvailable,
StreamToolOutputAvailable,
),
):
extra = ""
if isinstance(response, StreamToolOutputAvailable):
out_len = len(str(response.output))
extra = f", output_len={out_len}"
logger.info(
"[SDK] [%s] Tool event: %s, tool=%s%s",
session_id[:12],
type(response).__name__,
getattr(response, "toolName", "N/A"),
extra,
)
# Log errors being sent to frontend
if isinstance(response, StreamError):
logger.error(
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
session_id[:12],
response.errorText,
response.code,
)
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
# Emit compaction end if SDK finished compacting
for ev in await compaction.emit_end_if_ready(session):
yield ev
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
# Log tool events for debugging
if isinstance(
response,
(
StreamToolInputAvailable,
StreamToolOutputAvailable,
),
):
extra = ""
if isinstance(response, StreamToolOutputAvailable):
out_len = len(str(response.output))
extra = f", output_len={out_len}"
logger.info(
"[SDK] [%s] Tool event: %s, tool=%s%s",
session_id[:12],
type(response).__name__,
getattr(response, "toolName", "N/A"),
extra,
)
# Log errors being sent to frontend
if isinstance(response, StreamError):
logger.error(
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
session_id[:12],
response.errorText,
response.code,
)
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(
response.input or {}
),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
except asyncio.CancelledError:
# Task/generator was cancelled (e.g. client disconnect,
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
session_id[:12],
)
raise
finally:
# Cancel the pending __anext__ task to avoid a leaked
# coroutine. This is safe even if the task already
# completed.
if pending_task is not None and not pending_task.done():
pending_task.cancel()
try:
await pending_task
except (asyncio.CancelledError, StopAsyncIteration):
pass
# Safety net: if tools are still unresolved after the
# streaming loop (e.g. StopAsyncIteration before ResultMessage,
# or SDK not sending UserMessages for built-in tools), flush
# them now so the frontend stops showing spinners.
if adapter.has_unresolved_tool_calls:
logger.warning(
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
"flushing as safety net",
session_id[:12],
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
adapter._flush_unresolved_tool_calls(safety_responses)
for response in safety_responses:
if isinstance(
response,
(StreamToolInputAvailable, StreamToolOutputAvailable),
):
logger.info(
"[SDK] [%s] Safety flush: %s, tool=%s",
session_id[:12],
type(response).__name__,
getattr(response, "toolName", "N/A"),
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
yield response
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# If the stream ended without a ResultMessage, the SDK
# CLI exited unexpectedly or the user stopped execution.
# Close any open text/step so chunks are well-formed, and
# append a cancellation message so users see feedback.
# StreamFinish is published by mark_session_completed in the processor.
if not stream_completed:
logger.info(
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
session_id[:12],
)
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
for r in closing_responses:
yield r
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
# Add "Stopped by user" message so it persists after refresh
# Use COPILOT_SYSTEM_PREFIX so frontend renders it as system message, not assistant
session.messages.append(
ChatMessage(
role="assistant",
content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user",
elif isinstance(response, StreamFinish):
stream_completed = True
except asyncio.CancelledError:
# Task/generator was cancelled (e.g. client disconnect,
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
session_id[:12],
)
raise
finally:
# Cancel the pending __anext__ task to avoid a leaked
# coroutine. This is safe even if the task already
# completed.
if pending_task is not None and not pending_task.done():
pending_task.cancel()
try:
await pending_task
except (asyncio.CancelledError, StopAsyncIteration):
# Expected: task was cancelled or exhausted during cleanup
logger.info(
"[SDK] Pending __anext__ task completed during cleanup"
)
)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Upload transcript for next-turn --resume ---
# After async with the SDK task group has exited, so the Stop
# hook has already fired and the CLI has been SIGTERMed. The
# CLI uses appendFileSync, so all writes are safely on disk.
if config.claude_agent_use_resume and user_id:
# With --resume the CLI appends to the resume file (most
# complete). Otherwise use the Stop hook path.
if use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
logger.debug("[SDK] Transcript source: resume file")
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
else:
raw_transcript = None
if not raw_transcript:
logger.debug(
"[SDK] No usable transcript — CLI file had no "
"conversation entries (expected for first turn "
"without --resume)"
)
if raw_transcript:
# Shield the upload from generator cancellation so a
# client disconnect / page refresh doesn't lose the
# transcript. The upload must finish even if the SSE
# connection is torn down.
await asyncio.shield(
_try_upload_transcript(
user_id,
session_id,
raw_transcript,
message_count=len(session.messages),
# Safety net: if tools are still unresolved after the
# streaming loop (e.g. StopAsyncIteration before ResultMessage,
# or SDK not sending UserMessages for built-in tools), flush
# them now so the frontend stops showing spinners.
if adapter.has_unresolved_tool_calls:
logger.warning(
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
"flushing as safety net",
session_id[:12],
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
adapter._flush_unresolved_tool_calls(safety_responses)
for response in safety_responses:
if isinstance(
response,
(StreamToolInputAvailable, StreamToolOutputAvailable),
):
logger.info(
"[SDK] [%s] Safety flush: %s, tool=%s",
session_id[:12],
type(response).__name__,
getattr(response, "toolName", "N/A"),
)
)
yield response
except ImportError:
raise RuntimeError(
"claude-agent-sdk is not installed. "
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
"to use the OpenAI-compatible fallback."
)
# If the stream ended without a ResultMessage, the SDK
# CLI exited unexpectedly or the user stopped execution.
# Close any open text/step so chunks are well-formed, and
# append a cancellation message so users see feedback.
# StreamFinish is published by mark_session_completed in the processor.
if not stream_completed and not ended_with_stream_error:
logger.info(
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
session_id[:12],
)
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
for r in closing_responses:
yield r
# Add "Stopped by user" message so it persists after refresh
# Use COPILOT_SYSTEM_PREFIX so frontend renders it as system message, not assistant
session.messages.append(
ChatMessage(
role="assistant",
content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user",
)
)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Upload transcript for next-turn --resume ---
# After async with the SDK task group has exited, so the Stop
# hook has already fired and the CLI has been SIGTERMed. The
# CLI uses appendFileSync, so all writes are safely on disk.
if config.claude_agent_use_resume and user_id:
# With --resume the CLI appends to the resume file (most
# complete). Otherwise use the Stop hook path.
if use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
logger.debug("[SDK] Transcript source: resume file")
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
else:
raw_transcript = None
if not raw_transcript:
logger.debug(
"[SDK] No usable transcript — CLI file had no "
"conversation entries (expected for first turn "
"without --resume)"
)
if raw_transcript:
# Shield the upload from generator cancellation so a
# client disconnect / page refresh doesn't lose the
# transcript. The upload must finish even if the SSE
# connection is torn down.
await asyncio.shield(
_try_upload_transcript(
user_id,
session_id,
raw_transcript,
message_count=len(session.messages),
)
)
logger.info(
"[SDK] [%s] Stream completed successfully with %d messages",

### `backend/copilot/service.py`

Diff suppressed because it is too large.

### `backend/copilot/service_test.py`

@@ -4,75 +4,14 @@ from os import getenv
 import pytest
 from . import service as chat_service
 from .model import create_chat_session, get_chat_session, upsert_chat_session
-from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
+from .response_model import StreamError, StreamTextDelta
 from .sdk import service as sdk_service
 from .sdk.transcript import download_transcript

 logger = logging.getLogger(__name__)

-@pytest.mark.asyncio(loop_scope="session")
-async def test_stream_chat_completion(setup_test_user, test_user_id):
-    """
-    Test the stream_chat_completion function.
-    """
-    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
-    if not api_key:
-        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
-    session = await create_chat_session(test_user_id)
-    has_errors = False
-    assistant_message = ""
-    async for chunk in chat_service.stream_chat_completion(
-        session.session_id, "Hello, how are you?", user_id=session.user_id
-    ):
-        logger.info(chunk)
-        if isinstance(chunk, StreamError):
-            has_errors = True
-        if isinstance(chunk, StreamTextDelta):
-            assistant_message += chunk.delta
-    # StreamFinish is published by mark_session_completed (processor layer),
-    # not by the service. The generator completing means the stream ended.
-    assert not has_errors, "Error occurred while streaming chat completion"
-    assert assistant_message, "Assistant message is empty"
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
-    """
-    Test the stream_chat_completion function.
-    """
-    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
-    if not api_key:
-        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
-    session = await create_chat_session(test_user_id)
-    session = await upsert_chat_session(session)
-    has_errors = False
-    had_tool_calls = False
-    async for chunk in chat_service.stream_chat_completion(
-        session.session_id,
-        "Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
-        user_id=session.user_id,
-    ):
-        logger.info(chunk)
-        if isinstance(chunk, StreamError):
-            has_errors = True
-        if isinstance(chunk, StreamToolOutputAvailable):
-            had_tool_calls = True
-    assert not has_errors, "Error occurred while streaming chat completion"
-    assert had_tool_calls, "Tool calls did not occur"
-    session = await get_chat_session(session.session_id)
-    assert session, "Session not found"
-    assert session.usage, "Usage is empty"
-
 @pytest.mark.asyncio(loop_scope="session")
 async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
     """Test that the SDK --resume path captures and uses transcripts across turns.