Compare commits


10 Commits

Author SHA1 Message Date
Zamil Majdy
ecfe4e6a7a fix(copilot): RPC DataError reconstruction, chat stream reconnection
Fix two issues:

1. RPC DataError deserialization crash: When the database-manager
   returns a 400 for a Prisma DataError/UniqueViolationError, the
   client-side reconstruction crashes because DataError.__init__
   expects a dict but exc.args only contains a string message.
   Wrap the string in the expected dict structure so the exception
   is properly caught by callers (e.g. workspace file overwrites).

2. Chat stream reconnection on page refresh: The AI SDK's built-in
   resume:true fires before message hydration completes, causing
   hydrated messages to overwrite the resumed stream. Replace with
   manual resumeStream() called after hydration + active_stream
   detection. Show the stop button immediately when an active stream
   is detected (isReconnecting flag) and prevent sending new messages
   until reconnected.
2026-02-19 15:50:51 +08:00
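A minimal sketch of the client-side fix described in point 1 above, assuming, as the message states, that DataError's constructor expects a dict payload while the RPC layer only carried a plain string in exc.args. The stand-in class, helper name, and dict keys below are illustrative, not the project's actual RPC or Prisma code:

class DataError(Exception):
    """Stand-in for the Prisma DataError, whose __init__ expects a dict payload."""

    def __init__(self, data: dict):
        self.data = data
        super().__init__(data.get("message", ""))


def reconstruct_remote_error(exc_type: type[Exception], args: tuple) -> Exception:
    """Rebuild an exception received over RPC from its type and args."""
    if issubclass(exc_type, DataError) and args and isinstance(args[0], str):
        # The server serialized only the string message; wrap it in the dict
        # shape the constructor expects so callers' `except DataError:` blocks
        # (e.g. workspace file overwrites) still catch it instead of crashing
        # during reconstruction.
        return exc_type({"message": args[0]})
    return exc_type(*args)


print(reconstruct_remote_error(DataError, ("unique constraint violated",)))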
Otto (AGPT)
efb4b3b518 fix: Update _pending_tool_outputs type to dict[str, list[str]] 2026-02-19 02:42:05 +00:00
Otto (AGPT)
ebeab7fbe6 fix(copilot): Address GenericTool review comments
- Fix parseOutput treating arrays as objects (skip Array.isArray)
- Add React import for React.ReactNode type reference
- Differentiate web_fetch vs WebSearch title in accordion
2026-02-19 02:15:52 +00:00
Otto (AGPT)
98ef8a26ab fix(copilot): Address new review comments
- Guard metadata store() with try/except so failure doesn't orphan the
  already-uploaded transcript (coderabbit Major)
- Fix OrbitLoader size from 20 to 14 to match static icons
- Filter output.files to confirmed strings instead of unchecked cast
2026-02-19 01:57:47 +00:00
Otto (AGPT)
ed02e6db9e style: format GenericTool.tsx with prettier 2026-02-19 01:56:36 +00:00
Otto (AGPT)
6952334b85 fix(copilot): Address remaining review comments
- Tool output stashing: use FIFO queue per tool name instead of single
  value, so duplicate calls to the same tool in one turn each get their
  own output (fixes sentry HIGH/MEDIUM)
- Web accordion: show JSON fallback when output has no recognized text
  fields (fixes empty accordion body edge case)
- Cleanup dir logging: log when project dir not found
- Flush behavior and TodoItem cast are already correct as-is
2026-02-19 00:37:13 +00:00
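For the first bullet, a compact stand-alone illustration of the FIFO semantics, using a plain module-level dict rather than the ContextVar-backed stash_pending_tool_output / pop_pending_tool_output helpers shown in the tool_adapter.py diff further down:

from collections import defaultdict

_pending: dict[str, list[str]] = defaultdict(list)


def stash(tool_name: str, output: str) -> None:
    _pending[tool_name].append(output)  # newest output goes to the tail


def pop(tool_name: str) -> str | None:
    queue = _pending.get(tool_name)
    if not queue:
        return None
    return queue.pop(0)  # oldest output comes out first


# Two WebSearch calls in the same turn each keep their own output:
stash("WebSearch", "results for query A")
stash("WebSearch", "results for query B")
assert pop("WebSearch") == "results for query A"
assert pop("WebSearch") == "results for query B"
assert pop("WebSearch") is None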
Otto (AGPT)
0c586c2edf fix(copilot): Address PR review comments
- Shield transcript upload and session save from generator cancellation
  (asyncio.shield) so page refresh/disconnect doesn't lose the transcript
- Return content_base64 for small image files (not just text) so
  _extract_image_block can actually work
- Add 32KB size limit to _extract_image_block to prevent oversized images
- Fix gap fill when transcript_msg_count == 0 (metadata absent)
- Add truncation to files.join in GenericTool.tsx
2026-02-19 00:30:06 +00:00
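For the first bullet, a self-contained sketch of why asyncio.shield keeps the upload alive when the streaming generator is cancelled by a page refresh or disconnect; the coroutine names are generic placeholders, not the service's real functions:

import asyncio


async def upload_transcript(label: str) -> None:
    await asyncio.sleep(0.2)  # stand-in for the real upload
    print(f"uploaded ({label})")


async def handler(shielded: bool) -> None:
    # Stand-in for the SSE generator's cleanup step.
    if shielded:
        await asyncio.shield(upload_transcript("shielded"))
    else:
        await upload_transcript("bare")


async def main() -> None:
    for shielded in (False, True):
        task = asyncio.create_task(handler(shielded))
        await asyncio.sleep(0.05)
        task.cancel()  # simulate the client disconnecting mid-stream
        try:
            await task
        except asyncio.CancelledError:
            pass
        await asyncio.sleep(0.3)  # give a surviving inner upload time to finish


asyncio.run(main())  # only the shielded run prints "uploaded (shielded)"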
Zamil Majdy
b6128dd75f feat(copilot): stream resume, transcript staleness detection, WebSearch display
- Enable `resume: true` on `useChat` with `prepareReconnectToStreamRequest`
  so page refresh reconnects to active backend streams via Redis replay
- Add `message_count` watermark + timestamp metadata to transcript uploads;
  on download, detect staleness and compress only the gap instead of the
  full history (hybrid: transcript via --resume + compressed missed turns)
- Fix WebSearch accordion showing empty by extracting text from MCP-style
  content blocks (`extractMcpText`) with raw JSON fallback
- Revert over-blocking: only `AskUserQuestion` added to SDK_DISALLOWED_TOOLS
  (removed EnterPlanMode, ExitPlanMode, Skill, NotebookEdit)
- Add defensive TodoItem filter per coderabbit review
- Fix service_test for TranscriptDownload return type change
2026-02-19 05:09:41 +05:30
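A toy walkthrough of the message_count watermark and gap compression from the second bullet, with plain strings standing in for the session's ChatMessage objects:

# The transcript was uploaded when the session had 4 messages; the session now
# has 7, the last of which is the new user turn. Only the messages the
# transcript missed need to be compressed and prepended to the query.
messages = ["u1", "a1", "u2", "a2", "u3", "a3", "u4-new"]
transcript_msg_count = 4           # watermark stored with the uploaded transcript
current_msg_count = len(messages)  # 7

if transcript_msg_count < current_msg_count - 1:
    gap = messages[transcript_msg_count:-1]
    print(gap)  # ['u3', 'a3']: the missed turns; compress these, resume the rest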
Zamil Majdy
c4f5f7c8b8 Merge branch 'dev' into copilot/sdk-improvements 2026-02-19 00:14:23 +05:30
Zamil Majdy
8af4e0bf7d feat(copilot): SDK tool output, transcript resume, image support, GenericTool UI
- Fix SDK built-in tool outputs (WebSearch, Read, TodoWrite) not showing
  in frontend by stashing outputs via PostToolUse hook and flushing
  unresolved tool calls in response adapter
- Fix transcript-based --resume for multi-turn conversations: single
  clean upload block after async with, extracted _try_upload_transcript
  helper, removed redundant dual-strategy code
- Add image support in MCP tool results: detect workspace file responses
  with content_base64 and return MCP image content blocks so Claude can
  "see" small images (<32KB)
- Overhaul GenericTool.tsx with tool-specific icons, TodoWrite checklist
  rendering, WebSearch text display, and proper accordion content
- Downgrade 36 per-operation [TIMING]/[TASK_LOOKUP] diagnostic logs from
  info to debug in stream_registry.py and service.py
- Fix silent exceptions: add warning logs for swallowed ValueError/
  TypeError in stream_registry and Exception in service long-running path
- Clean up transcript.py: remove dead find_cli_transcript and
  read_fallback_transcript functions, simplify logging
2026-02-19 00:11:08 +05:30
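On the image-support bullet, a quick check of the size budget: base64 turns every 3 raw bytes into 4 characters, so the roughly 43 000-byte base64 cap used in the tool_adapter.py diff below corresponds to about 32 KB of raw image data. The payload here is synthetic:

import base64

raw = b"\x89PNG" + b"\x00" * 32_246   # fake payload: 32 250 raw bytes
encoded = base64.b64encode(raw)
print(len(raw), len(encoded))         # 32250 43000 (a 4/3 ratio)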
17 changed files with 1313 additions and 179 deletions

View File

@@ -4,7 +4,6 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
 RabbitMQ and processes them using a thread pool, following the graph executor pattern.
 """
-import asyncio
 import logging
 import os
 import threading

View File

@@ -53,6 +53,7 @@ class SDKResponseAdapter:
 self.has_started_text = False
 self.has_ended_text = False
 self.current_tool_calls: dict[str, dict[str, str]] = {}
+self.resolved_tool_calls: set[str] = set()
 self.task_id: str | None = None
 self.step_open = False
@@ -74,6 +75,10 @@
 self.step_open = True
 elif isinstance(sdk_message, AssistantMessage):
+# Flush any SDK built-in tool calls that didn't get a UserMessage
+# result (e.g. WebSearch, Read handled internally by the CLI).
+self._flush_unresolved_tool_calls(responses)
 # After tool results, the SDK sends a new AssistantMessage for the
 # next LLM turn. Open a new step if the previous one was closed.
 if not self.step_open:
@@ -111,6 +116,8 @@
 # UserMessage carries tool results back from tool execution.
 content = sdk_message.content
 blocks = content if isinstance(content, list) else []
+resolved_in_blocks: set[str] = set()
 for block in blocks:
 if isinstance(block, ToolResultBlock) and block.tool_use_id:
 tool_info = self.current_tool_calls.get(block.tool_use_id, {})
@@ -132,6 +139,37 @@
 success=not (block.is_error or False),
 )
 )
+resolved_in_blocks.add(block.tool_use_id)
+# Handle SDK built-in tool results carried via parent_tool_use_id
+# instead of (or in addition to) ToolResultBlock content.
+parent_id = sdk_message.parent_tool_use_id
+if parent_id and parent_id not in resolved_in_blocks:
+tool_info = self.current_tool_calls.get(parent_id, {})
+tool_name = tool_info.get("name", "unknown")
+# Try stashed output first (from PostToolUse hook),
+# then tool_use_result dict, then string content.
+output = pop_pending_tool_output(tool_name)
+if not output:
+tur = sdk_message.tool_use_result
+if tur is not None:
+output = _extract_tool_use_result(tur)
+if not output and isinstance(content, str) and content.strip():
+output = content.strip()
+if output:
+responses.append(
+StreamToolOutputAvailable(
+toolCallId=parent_id,
+toolName=tool_name,
+output=output,
+success=True,
+)
+)
+resolved_in_blocks.add(parent_id)
+self.resolved_tool_calls.update(resolved_in_blocks)
 # Close the current step after tool results — the next
 # AssistantMessage will open a new step for the continuation.
@@ -140,6 +178,7 @@
 self.step_open = False
 elif isinstance(sdk_message, ResultMessage):
+self._flush_unresolved_tool_calls(responses)
 self._end_text_if_open(responses)
 # Close the step before finishing.
 if self.step_open:
@@ -149,7 +188,7 @@
 if sdk_message.subtype == "success":
 responses.append(StreamFinish())
 elif sdk_message.subtype in ("error", "error_during_execution"):
-error_msg = getattr(sdk_message, "result", None) or "Unknown error"
+error_msg = sdk_message.result or "Unknown error"
 responses.append(
 StreamError(errorText=str(error_msg), code="sdk_error")
 )
@@ -180,6 +219,59 @@
 responses.append(StreamTextEnd(id=self.text_block_id))
 self.has_ended_text = True
+def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
+"""Emit outputs for tool calls that didn't receive a UserMessage result.
+SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
+internally without surfacing a separate ``UserMessage`` with
+``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
+output, which we pop and emit here before the next ``AssistantMessage``
+starts.
+"""
+flushed = False
+for tool_id, tool_info in self.current_tool_calls.items():
+if tool_id in self.resolved_tool_calls:
+continue
+tool_name = tool_info.get("name", "unknown")
+output = pop_pending_tool_output(tool_name)
+if output is not None:
+responses.append(
+StreamToolOutputAvailable(
+toolCallId=tool_id,
+toolName=tool_name,
+output=output,
+success=True,
+)
+)
+self.resolved_tool_calls.add(tool_id)
+flushed = True
+logger.debug(
+f"Flushed pending output for built-in tool {tool_name} "
+f"(call {tool_id})"
+)
+else:
+# No output available — emit an empty output so the frontend
+# transitions the tool from input-available to output-available
+# (stops the spinner).
+responses.append(
+StreamToolOutputAvailable(
+toolCallId=tool_id,
+toolName=tool_name,
+output="",
+success=True,
+)
+)
+self.resolved_tool_calls.add(tool_id)
+flushed = True
+logger.debug(
+f"Flushed empty output for unresolved tool {tool_name} "
+f"(call {tool_id})"
+)
+if flushed and self.step_open:
+responses.append(StreamFinishStep())
+self.step_open = False
 def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
 """Extract a string output from a ToolResultBlock's content field."""
@@ -199,3 +291,30 @@ def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
 return json.dumps(content)
 except (TypeError, ValueError):
 return str(content)
+def _extract_tool_use_result(result: object) -> str:
+"""Extract a string from a UserMessage's ``tool_use_result`` dict.
+SDK built-in tools may store their result in ``tool_use_result``
+instead of (or in addition to) ``ToolResultBlock`` content blocks.
+"""
+if isinstance(result, str):
+return result
+if isinstance(result, dict):
+# Try common result keys
+for key in ("content", "text", "output", "stdout", "result"):
+val = result.get(key)
+if isinstance(val, str) and val:
+return val
+# Fall back to JSON serialization of the whole dict
+try:
+return json.dumps(result)
+except (TypeError, ValueError):
+return str(result)
+if result is None:
+return ""
+try:
+return json.dumps(result)
+except (TypeError, ValueError):
+return str(result)

View File

@@ -16,6 +16,7 @@ from .tool_adapter import (
 DANGEROUS_PATTERNS,
 MCP_TOOL_PREFIX,
 WORKSPACE_SCOPED_TOOLS,
+stash_pending_tool_output,
 )
 logger = logging.getLogger(__name__)
@@ -224,10 +225,25 @@
 tool_use_id: str | None,
 context: HookContext,
 ) -> SyncHookJSONOutput:
-"""Log successful tool executions for observability."""
+"""Log successful tool executions and stash SDK built-in tool outputs.
+MCP tools stash their output in ``_execute_tool_sync`` before the
+SDK can truncate it. SDK built-in tools (WebSearch, Read, etc.)
+are executed by the CLI internally — this hook captures their
+output so the response adapter can forward it to the frontend.
+"""
 _ = context
 tool_name = cast(str, input_data.get("tool_name", ""))
 logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
+# Stash output for SDK built-in tools so the response adapter can
+# emit StreamToolOutputAvailable even when the CLI doesn't surface
+# a separate UserMessage with ToolResultBlock content.
+if not tool_name.startswith(MCP_TOOL_PREFIX):
+tool_response = input_data.get("tool_response")
+if tool_response is not None:
+stash_pending_tool_output(tool_name, tool_response)
 return cast(SyncHookJSONOutput, {})
 async def post_tool_failure_hook(

View File

@@ -47,6 +47,7 @@ from .tool_adapter import (
 set_execution_context,
 )
 from .transcript import (
+cleanup_cli_project_dir,
 download_transcript,
 read_transcript_file,
 upload_transcript,
@@ -86,9 +87,12 @@ _SDK_TOOL_SUPPLEMENT = """
 for shell commands — it runs in a network-isolated sandbox.
 - **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
 same working directory. Files created by one are readable by the other.
-These files are **ephemeral** — they exist only for the current session.
-- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
-for files that should persist across sessions (stored in cloud storage).
+- **IMPORTANT — File persistence**: Your working directory is **ephemeral** —
+files are lost between turns. When you create or modify important files
+(code, configs, outputs), you MUST save them using `write_workspace_file`
+so they persist. Use `read_workspace_file` and `list_workspace_files` to
+access files saved in previous turns. If a "Files from previous turns"
+section is present above, those files are available via `read_workspace_file`.
 - Long-running tools (create_agent, edit_agent, etc.) are handled
 asynchronously. You will receive an immediate response; the actual result
 is delivered to the user via a background stream.
@@ -268,48 +272,28 @@ def _make_sdk_cwd(session_id: str) -> str:
 def _cleanup_sdk_tool_results(cwd: str) -> None:
-"""Remove SDK tool-result files for a specific session working directory.
-The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
-We clean only the specific cwd's results to avoid race conditions between
-concurrent sessions.
+"""Remove SDK session artifacts for a specific working directory.
+Cleans up:
+- ``~/.claude/projects/<encoded-cwd>/`` — CLI session transcripts and
+tool-result files. Each SDK turn uses a unique cwd, so this directory
+is safe to remove entirely.
+- ``/tmp/copilot-<session>/`` — the ephemeral working directory.
-Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
+Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
+the session_id.
 """
 import shutil
-# Validate cwd is under the expected prefix
 normalized = os.path.normpath(cwd)
 if not normalized.startswith(_SDK_CWD_PREFIX):
 logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
 return
-# SDK encodes the cwd path by replacing '/' with '-'
-encoded_cwd = normalized.replace("/", "-")
-# Construct the project directory path (known-safe home expansion)
-claude_projects = os.path.expanduser("~/.claude/projects")
-project_dir = os.path.join(claude_projects, encoded_cwd)
-# Security check 3: Validate project_dir is under ~/.claude/projects
-project_dir = os.path.normpath(project_dir)
-if not project_dir.startswith(claude_projects):
-logger.warning(
-f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
-)
-return
-results_dir = os.path.join(project_dir, "tool-results")
-if os.path.isdir(results_dir):
-for filename in os.listdir(results_dir):
-file_path = os.path.join(results_dir, filename)
-try:
-if os.path.isfile(file_path):
-os.remove(file_path)
-except OSError:
-pass
-# Also clean up the temp cwd directory itself
+# Clean the CLI's project directory (transcripts + tool-results).
+cleanup_cli_project_dir(cwd)
+# Clean up the temp cwd directory itself.
 try:
 shutil.rmtree(normalized, ignore_errors=True)
 except OSError:
@@ -519,6 +503,7 @@
 def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
 captured_transcript.path = transcript_path
 captured_transcript.sdk_session_id = sdk_session_id
+logger.debug(f"[SDK] Stop hook: path={transcript_path!r}")
 security_hooks = create_security_hooks(
 user_id,
@@ -530,18 +515,20 @@
 # --- Resume strategy: download transcript from bucket ---
 resume_file: str | None = None
 use_resume = False
+transcript_msg_count = 0 # watermark: session.messages length at upload
 if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
-transcript_content = await download_transcript(user_id, session_id)
-if transcript_content and validate_transcript(transcript_content):
+dl = await download_transcript(user_id, session_id)
+if dl and validate_transcript(dl.content):
 resume_file = write_transcript_to_tempfile(
-transcript_content, session_id, sdk_cwd
+dl.content, session_id, sdk_cwd
 )
 if resume_file:
 use_resume = True
-logger.info(
-f"[SDK] Using --resume with transcript "
-f"({len(transcript_content)} bytes)"
+transcript_msg_count = dl.message_count
+logger.debug(
+f"[SDK] Using --resume ({len(dl.content)}B, "
+f"msg_count={transcript_msg_count})"
 )
 sdk_options_kwargs: dict[str, Any] = {
@@ -582,11 +569,35 @@
 # Build query: with --resume the CLI already has full
 # context, so we only send the new message. Without
 # resume, compress history into a context prefix.
+#
+# Hybrid mode: if the transcript is stale (upload missed
+# some turns), compress only the gap and prepend it so
+# the agent has transcript context + missed turns.
 query_message = current_message
-if not use_resume and len(session.messages) > 1:
+current_msg_count = len(session.messages)
+if use_resume and transcript_msg_count >= 0:
+# Transcript covers messages[0..M-1]. Current session
+# has N messages (last one is the new user msg).
+# Gap = messages[M .. N-2] (everything between upload
+# and the current turn).
+if transcript_msg_count < current_msg_count - 1:
+gap = session.messages[transcript_msg_count:-1]
+gap_context = _format_conversation_context(gap)
+if gap_context:
+logger.info(
+f"[SDK] Transcript stale: covers {transcript_msg_count} "
+f"of {current_msg_count} messages, compressing "
+f"{len(gap)} missed messages"
+)
+query_message = (
+f"{gap_context}\n\n"
+f"Now, the user says:\n{current_message}"
+)
+elif not use_resume and current_msg_count > 1:
 logger.warning(
 f"[SDK] Using compression fallback for session "
-f"{session_id} ({len(session.messages)} messages) — "
+f"{session_id} ({current_msg_count} messages) — "
 f"no transcript available for --resume"
 )
 compressed = await _compress_conversation_history(session)
@@ -597,10 +608,10 @@
 f"Now, the user says:\n{current_message}"
 )
-logger.info(
-f"[SDK] Sending query ({len(session.messages)} msgs in session)"
+logger.debug(
+f"[SDK] Sending query ({len(session.messages)} msgs, "
+f"resume={use_resume})"
 )
-logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
 await client.query(query_message, session_id=session_id)
 assistant_response = ChatMessage(role="assistant", content="")
@@ -681,29 +692,33 @@
 ) and not has_appended_assistant:
 session.messages.append(assistant_response)
-# --- Capture transcript while CLI is still alive ---
-# Must happen INSIDE async with: close() sends SIGTERM
-# which kills the CLI before it can flush the JSONL.
-if (
-config.claude_agent_use_resume
-and user_id
-and captured_transcript.available
-):
-# Give CLI time to flush JSONL writes before we read
-await asyncio.sleep(0.5)
+# --- Upload transcript for next-turn --resume ---
+# After async with the SDK task group has exited, so the Stop
+# hook has already fired and the CLI has been SIGTERMed. The
+# CLI uses appendFileSync, so all writes are safely on disk.
+if config.claude_agent_use_resume and user_id:
+# With --resume the CLI appends to the resume file (most
+# complete). Otherwise use the Stop hook path.
+if use_resume and resume_file:
+raw_transcript = read_transcript_file(resume_file)
+elif captured_transcript.path:
 raw_transcript = read_transcript_file(captured_transcript.path)
-if raw_transcript:
-try:
-async with asyncio.timeout(30):
-await _upload_transcript_bg(
-user_id, session_id, raw_transcript
-)
-except asyncio.TimeoutError:
-logger.warning(
-f"[SDK] Transcript upload timed out for {session_id}"
-)
 else:
-logger.debug("[SDK] Stop hook fired but transcript not usable")
+raw_transcript = None
+if raw_transcript:
+# Shield the upload from generator cancellation so a
+# client disconnect / page refresh doesn't lose the
+# transcript. The upload must finish even if the SSE
+# connection is torn down.
+await asyncio.shield(
+_try_upload_transcript(
+user_id,
+session_id,
+raw_transcript,
+message_count=len(session.messages),
+)
+)
 except ImportError:
 raise RuntimeError(
@@ -712,7 +727,7 @@
 "to use the OpenAI-compatible fallback."
 )
-await upsert_chat_session(session)
+await asyncio.shield(upsert_chat_session(session))
 logger.debug(
 f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
 )
@@ -722,7 +737,7 @@
 except Exception as e:
 logger.error(f"[SDK] Error: {e}", exc_info=True)
 try:
-await upsert_chat_session(session)
+await asyncio.shield(upsert_chat_session(session))
 except Exception as save_err:
 logger.error(f"[SDK] Failed to save session on error: {save_err}")
 yield StreamError(
@@ -735,14 +750,31 @@
 _cleanup_sdk_tool_results(sdk_cwd)
-async def _upload_transcript_bg(
-user_id: str, session_id: str, raw_content: str
-) -> None:
-"""Background task to strip progress entries and upload transcript."""
+async def _try_upload_transcript(
+user_id: str,
+session_id: str,
+raw_content: str,
+message_count: int = 0,
+) -> bool:
+"""Strip progress entries and upload transcript (with timeout).
+Returns True if the upload completed without error.
+"""
 try:
-await upload_transcript(user_id, session_id, raw_content)
+async with asyncio.timeout(30):
+await upload_transcript(
+user_id, session_id, raw_content, message_count=message_count
+)
+return True
+except asyncio.TimeoutError:
+logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
+return False
 except Exception as e:
-logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
+logger.error(
+f"[SDK] Failed to upload transcript for {session_id}: {e}",
+exc_info=True,
+)
+return False
 async def _update_title_async(

View File

@@ -41,7 +41,7 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
 # Stash for MCP tool outputs before the SDK potentially truncates them.
 # Keyed by tool_name → full output string. Consumed (popped) by the
 # response adapter when it builds StreamToolOutputAvailable.
-_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
+_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
 "pending_tool_outputs", default=None # type: ignore[arg-type]
 )
@@ -88,19 +88,52 @@
 def pop_pending_tool_output(tool_name: str) -> str | None:
-"""Pop and return the stashed full output for *tool_name*.
+"""Pop and return the oldest stashed output for *tool_name*.
 The SDK CLI may truncate large tool results (writing them to disk and
 replacing the content with a file reference). This stash keeps the
 original MCP output so the response adapter can forward it to the
 frontend for proper widget rendering.
+Uses a FIFO queue per tool name so duplicate calls to the same tool
+in one turn each get their own output.
 Returns ``None`` if nothing was stashed for *tool_name*.
 """
 pending = _pending_tool_outputs.get(None)
 if pending is None:
 return None
-return pending.pop(tool_name, None)
+queue = pending.get(tool_name)
+if not queue:
+pending.pop(tool_name, None)
+return None
+value = queue.pop(0)
+if not queue:
+del pending[tool_name]
+return value
+def stash_pending_tool_output(tool_name: str, output: Any) -> None:
+"""Stash tool output for later retrieval by the response adapter.
+Used by the PostToolUse hook to capture SDK built-in tool outputs
+(WebSearch, Read, etc.) that aren't available through the MCP stash
+mechanism in ``_execute_tool_sync``.
+Appends to a FIFO queue per tool name so multiple calls to the same
+tool in one turn are all preserved.
+"""
+pending = _pending_tool_outputs.get(None)
+if pending is None:
+return
+if isinstance(output, str):
+text = output
+else:
+try:
+text = json.dumps(output)
+except (TypeError, ValueError):
+text = str(output)
+pending.setdefault(tool_name, []).append(text)
 async def _execute_tool_sync(
@@ -125,14 +158,63 @@
 # Stash the full output before the SDK potentially truncates it.
 pending = _pending_tool_outputs.get(None)
 if pending is not None:
-pending[base_tool.name] = text
+pending.setdefault(base_tool.name, []).append(text)
+content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
+# If the tool result contains inline image data, add an MCP image block
+# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
+image_block = _extract_image_block(text)
+if image_block:
+content_blocks.append(image_block)
 return {
-"content": [{"type": "text", "text": text}],
+"content": content_blocks,
 "isError": not result.success,
 }
+# MIME types that Claude can process as image content blocks.
+_SUPPORTED_IMAGE_TYPES = frozenset(
+{"image/png", "image/jpeg", "image/gif", "image/webp"}
+)
+def _extract_image_block(text: str) -> dict[str, str] | None:
+"""Extract an MCP image content block from a tool result JSON string.
+Detects workspace file responses with ``content_base64`` and an image
+MIME type, returning an MCP-format image block that allows Claude to
+"see" the image. Returns ``None`` if the result is not an inline image.
+"""
+try:
+data = json.loads(text)
+except (json.JSONDecodeError, TypeError):
+return None
+if not isinstance(data, dict):
+return None
+mime_type = data.get("mime_type", "")
+base64_content = data.get("content_base64", "")
+# Only inline small images — large ones would exceed Claude's limits.
+# 32 KB raw ≈ ~43 KB base64.
+_MAX_IMAGE_BASE64_BYTES = 43_000
+if (
+mime_type in _SUPPORTED_IMAGE_TYPES
+and base64_content
+and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
+):
+return {
+"type": "image",
+"data": base64_content,
+"mimeType": mime_type,
+}
+return None
 def _mcp_error(message: str) -> dict[str, Any]:
 return {
 "content": [
@@ -311,14 +393,29 @@ def create_copilot_mcp_server():
 # which provides kernel-level network isolation via unshare --net.
 # Task allows spawning sub-agents (rate-limited by security hooks).
 # WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
-_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"]
+# TodoWrite manages the task checklist shown in the UI — no security concern.
+_SDK_BUILTIN_TOOLS = [
+"Read",
+"Write",
+"Edit",
+"Glob",
+"Grep",
+"Task",
+"WebSearch",
+"TodoWrite",
+]
 # SDK built-in tools that must be explicitly blocked.
 # Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
 # network isolation (unshare --net) instead.
 # WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
 # Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
-SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]
+# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
+SDK_DISALLOWED_TOOLS = [
+"Bash",
+"WebFetch",
+"AskUserQuestion",
+]
 # Tools that are blocked entirely in security hooks (defence-in-depth).
 # Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.

View File

@@ -14,6 +14,8 @@ import json
 import logging
 import os
 import re
+import time
+from dataclasses import dataclass
 logger = logging.getLogger(__name__)
@@ -31,6 +33,16 @@
 {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
 )
+@dataclass
+class TranscriptDownload:
+"""Result of downloading a transcript with its metadata."""
+content: str
+message_count: int = 0 # session.messages length when uploaded
+uploaded_at: float = 0.0 # epoch timestamp of upload
 # Workspace storage constants — deterministic path from session_id.
 TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
@@ -119,23 +131,19 @@
 content = f.read()
 if not content.strip():
-logger.debug(f"[Transcript] Empty file: {transcript_path}")
 return None
 lines = content.strip().split("\n")
 if len(lines) < 3:
 # Raw files with ≤2 lines are metadata-only
 # (queue-operation + file-history-snapshot, no conversation).
-logger.debug(
-f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
-)
 return None
 # Quick structural validation — parse first and last lines.
 json.loads(lines[0])
 json.loads(lines[-1])
-logger.info(
+logger.debug(
 f"[Transcript] Read {len(lines)} lines, "
 f"{len(content)} bytes from {transcript_path}"
 )
@@ -160,6 +168,41 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
 _SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
+def _encode_cwd_for_cli(cwd: str) -> str:
+"""Encode a working directory path the same way the Claude CLI does.
+The CLI replaces all non-alphanumeric characters with ``-``.
+"""
+return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
+def cleanup_cli_project_dir(sdk_cwd: str) -> None:
+"""Remove the CLI's project directory for a specific working directory.
+The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
+Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
+safe to remove entirely after the transcript has been uploaded.
+"""
+import shutil
+cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
+config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
+projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
+project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
+if not project_dir.startswith(projects_base + os.sep):
+logger.warning(
+f"[Transcript] Cleanup path escaped projects base: {project_dir}"
+)
+return
+if os.path.isdir(project_dir):
+shutil.rmtree(project_dir, ignore_errors=True)
+logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
+else:
+logger.debug(f"[Transcript] Project dir not found: {project_dir}")
 def write_transcript_to_tempfile(
 transcript_content: str,
 session_id: str,
@@ -191,7 +234,7 @@ def write_transcript_to_tempfile(
 with open(jsonl_path, "w") as f:
 f.write(transcript_content)
-logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
+logger.debug(f"[Transcript] Wrote resume file: {jsonl_path}")
 return jsonl_path
 except OSError as e:
@@ -248,6 +291,15 @@ def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
 )
+def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
+"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
+return (
+TRANSCRIPT_STORAGE_PREFIX,
+_sanitize_id(user_id),
+f"{_sanitize_id(session_id)}.meta.json",
+)
 def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
 """Build the full storage path string that ``retrieve()`` expects.
@@ -268,21 +320,30 @@ def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
 return f"local://{wid}/{fid}/{fname}"
-async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
+async def upload_transcript(
+user_id: str,
+session_id: str,
+content: str,
+message_count: int = 0,
+) -> None:
 """Strip progress entries and upload transcript to bucket storage.
 Safety: only overwrites when the new (stripped) transcript is larger than
 what is already stored. Since JSONL is append-only, the latest transcript
 is always the longest. This prevents a slow/stale background task from
 clobbering a newer upload from a concurrent turn.
+Args:
+message_count: ``len(session.messages)`` at upload time — used by
+the next turn to detect staleness and compress only the gap.
 """
 from backend.util.workspace_storage import get_workspace_storage
 stripped = strip_progress_entries(content)
 if not validate_transcript(stripped):
 logger.warning(
-f"[Transcript] Skipping upload — stripped content is not a valid "
-f"transcript for session {session_id}"
+f"[Transcript] Skipping upload — stripped content not valid "
+f"for session {session_id}"
 )
 return
@@ -296,10 +357,9 @@ async def upload_transcript(user_id: str, session_id: str, content: str) -> None
 try:
 existing = await storage.retrieve(path)
 if len(existing) >= new_size:
-logger.info(
-f"[Transcript] Skipping upload — existing transcript "
-f"({len(existing)}B) >= new ({new_size}B) for session "
-f"{session_id}"
+logger.debug(
+f"[Transcript] Skipping upload — existing ({len(existing)}B) "
+f">= new ({new_size}B) for session {session_id}"
 )
 return
 except (FileNotFoundError, Exception):
@@ -311,16 +371,38 @@
 filename=fname,
 content=encoded,
 )
+# Store metadata alongside the transcript so the next turn can detect
+# staleness and only compress the gap instead of the full history.
+# Wrapped in try/except so a metadata write failure doesn't orphan
+# the already-uploaded transcript — the next turn will just fall back
+# to full gap fill (msg_count=0).
+try:
+meta = {"message_count": message_count, "uploaded_at": time.time()}
+mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
+await storage.store(
+workspace_id=mwid,
+file_id=mfid,
+filename=mfname,
+content=json.dumps(meta).encode("utf-8"),
+)
+except Exception as e:
+logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
 logger.info(
-f"[Transcript] Uploaded {new_size} bytes "
-f"(stripped from {len(content)}) for session {session_id}"
+f"[Transcript] Uploaded {new_size}B "
+f"(stripped from {len(content)}B, msg_count={message_count}) "
+f"for session {session_id}"
 )
-async def download_transcript(user_id: str, session_id: str) -> str | None:
-"""Download transcript from bucket storage.
+async def download_transcript(
+user_id: str, session_id: str
+) -> TranscriptDownload | None:
+"""Download transcript and metadata from bucket storage.
-Returns the JSONL content string, or ``None`` if not found.
+Returns a ``TranscriptDownload`` with the JSONL content and the
+``message_count`` watermark from the upload, or ``None`` if not found.
 """
 from backend.util.workspace_storage import get_workspace_storage
@@ -330,10 +412,6 @@ async def download_transcript(user_id: str, session_id: str) -> str | None:
 try:
 data = await storage.retrieve(path)
 content = data.decode("utf-8")
-logger.info(
-f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
-)
-return content
 except FileNotFoundError:
 logger.debug(f"[Transcript] No transcript in storage for {session_id}")
 return None
@@ -341,6 +419,36 @@
 logger.warning(f"[Transcript] Failed to download transcript: {e}")
 return None
+# Try to load metadata (best-effort — old transcripts won't have it)
+message_count = 0
+uploaded_at = 0.0
+try:
+from backend.util.workspace_storage import GCSWorkspaceStorage
+mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
+if isinstance(storage, GCSWorkspaceStorage):
+blob = f"workspaces/{mwid}/{mfid}/{mfname}"
+meta_path = f"gcs://{storage.bucket_name}/{blob}"
+else:
+meta_path = f"local://{mwid}/{mfid}/{mfname}"
+meta_data = await storage.retrieve(meta_path)
+meta = json.loads(meta_data.decode("utf-8"))
+message_count = meta.get("message_count", 0)
+uploaded_at = meta.get("uploaded_at", 0.0)
+except (FileNotFoundError, json.JSONDecodeError, Exception):
+pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
+logger.debug(
+f"[Transcript] Downloaded {len(content)}B "
+f"(msg_count={message_count}) for session {session_id}"
+)
+return TranscriptDownload(
+content=content,
+message_count=message_count,
+uploaded_at=uploaded_at,
+)
 async def delete_transcript(user_id: str, session_id: str) -> None:
 """Delete transcript from bucket storage (e.g. after resume failure)."""

View File

@@ -387,7 +387,7 @@ async def stream_chat_completion(
 if user_id:
 log_meta["user_id"] = user_id
-logger.info(
+logger.debug(
 f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
 f"message_len={len(message) if message else 0}, is_user={is_user_message}",
 extra={
@@ -404,7 +404,7 @@
 fetch_start = time.monotonic()
 session = await get_chat_session(session_id, user_id)
 fetch_time = (time.monotonic() - fetch_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
 f"n_messages={len(session.messages) if session else 0}",
 extra={
@@ -416,7 +416,7 @@
 },
 )
 else:
-logger.info(
+logger.debug(
 f"[TIMING] Using provided session, messages={len(session.messages)}",
 extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
 )
@@ -450,7 +450,7 @@
 message_length=len(message),
 )
 posthog_time = (time.monotonic() - posthog_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] track_user_message took {posthog_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
 )
@@ -458,7 +458,7 @@
 upsert_start = time.monotonic()
 session = await upsert_chat_session(session)
 upsert_time = (time.monotonic() - upsert_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
 )
@@ -503,7 +503,7 @@
 prompt_start = time.monotonic()
 system_prompt, understanding = await _build_system_prompt(user_id)
 prompt_time = (time.monotonic() - prompt_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
 )
@@ -537,7 +537,7 @@
 # Only yield message start for the initial call, not for continuations.
 setup_time = (time.monotonic() - completion_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
 extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
 )
@@ -548,7 +548,7 @@
 yield StreamStartStep()
 try:
-logger.info(
+logger.debug(
 "[TIMING] Calling _stream_chat_chunks",
 extra={"json_fields": log_meta},
 )
@@ -988,7 +988,7 @@
 if session.user_id:
 log_meta["user_id"] = session.user_id
-logger.info(
+logger.debug(
 f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
 f"user={session.user_id}, n_messages={len(session.messages)}",
 extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
@@ -1011,7 +1011,7 @@
 base_url=config.base_url,
 )
 context_time = (time_module.perf_counter() - context_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] _manage_context_window took {context_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": context_time}},
 )
@@ -1053,7 +1053,7 @@
 retry_info = (
 f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
 )
-logger.info(
+logger.debug(
 f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
 extra={
 "json_fields": {
@@ -1093,7 +1093,7 @@
 extra_body=extra_body,
 )
 api_init_time = (time_module.perf_counter() - api_call_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
 )
@@ -1142,7 +1142,7 @@
 ttfc = (
 time_module.perf_counter() - api_call_start
 ) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
 f"(since API call), n_chunks={chunk_count}",
 extra={
@@ -1210,7 +1210,7 @@
 )
 emitted_start_for_idx.add(idx)
 stream_duration = time_module.perf_counter() - api_call_start
-logger.info(
+logger.debug(
 f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
 f"duration={stream_duration:.2f}s, "
 f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
@@ -1244,7 +1244,7 @@
 raise
 total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
 f"session={session.session_id}, user={session.user_id}",
 extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
@@ -1494,8 +1494,8 @@
 # Mark stream registry task as failed if it was created
 try:
 await stream_registry.mark_task_completed(task_id, status="failed")
-except Exception:
-pass
+except Exception as mark_err:
+logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
 logger.error(
 f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
 )

View File

@@ -143,7 +143,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"Transcript was not uploaded to bucket after turn 1 — " "Transcript was not uploaded to bucket after turn 1 — "
"Stop hook may not have fired or transcript was too small" "Stop hook may not have fired or transcript was too small"
) )
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes") logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
# Reload session for turn 2 # Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id) session = await get_chat_session(session.session_id, test_user_id)

View File

@@ -117,7 +117,7 @@ async def create_task(
 if user_id:
 log_meta["user_id"] = user_id
-logger.info(
+logger.debug(
 f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
 extra={"json_fields": log_meta},
 )
@@ -135,7 +135,7 @@
 redis_start = time.perf_counter()
 redis = await get_redis_async()
 redis_time = (time.perf_counter() - redis_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] get_redis_async took {redis_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
 )
@@ -158,7 +158,7 @@
 },
 )
 hset_time = (time.perf_counter() - hset_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] redis.hset took {hset_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
 )
@@ -169,7 +169,7 @@
 await redis.set(op_key, task_id, ex=config.stream_ttl)
 total_time = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
 extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
 )
@@ -230,7 +230,7 @@
 in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
 or total_time > 50
 ):
-logger.info(
+logger.debug(
 f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
 extra={
 "json_fields": {
@@ -279,7 +279,7 @@
 if user_id:
 log_meta["user_id"] = user_id
-logger.info(
+logger.debug(
 f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
 extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
 )
@@ -289,14 +289,14 @@
 meta_key = _get_task_meta_key(task_id)
 meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
 hgetall_time = (time.perf_counter() - redis_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
 extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
 )
 if not meta:
 elapsed = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
 extra={
 "json_fields": {
@@ -335,7 +335,7 @@
 xread_start = time.perf_counter()
 messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
 xread_time = (time.perf_counter() - xread_start) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
 extra={
 "json_fields": {
@@ -363,7 +363,7 @@
 except Exception as e:
 logger.warning(f"Failed to replay message: {e}")
-logger.info(
+logger.debug(
 f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
 extra={
 "json_fields": {
@@ -376,7 +376,7 @@
 # Step 2: If task is still running, start stream listener for live updates
 if task_status == "running":
-logger.info(
+logger.debug(
 "[TIMING] Task still running, starting _stream_listener",
 extra={"json_fields": {**log_meta, "task_status": task_status}},
 )
@@ -387,14 +387,14 @@
 _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
 else:
 # Task is completed/failed - add finish marker
-logger.info(
+logger.debug(
 f"[TIMING] Task already {task_status}, adding StreamFinish",
 extra={"json_fields": {**log_meta, "task_status": task_status}},
 )
 await subscriber_queue.put(StreamFinish())
 total_time = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
 f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
 f"n_messages_replayed={replayed_count}",
 extra={
@@ -433,7 +433,7 @@
 if log_meta is None:
 log_meta = {"component": "StreamRegistry", "task_id": task_id}
-logger.info(
+logger.debug(
 f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
 extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
 )
@@ -462,7 +462,7 @@
 if messages:
 msg_count = sum(len(msgs) for _, msgs in messages)
logger.info( logger.debug(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms", f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={ extra={
"json_fields": { "json_fields": {
@@ -475,7 +475,7 @@ async def _stream_listener(
) )
elif xread_time > 1000: elif xread_time > 1000:
# Only log timeouts (30s blocking) # Only log timeouts (30s blocking)
logger.info( logger.debug(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms", f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={ extra={
"json_fields": { "json_fields": {
@@ -526,7 +526,7 @@ async def _stream_listener(
if first_message_time is None: if first_message_time is None:
first_message_time = time.perf_counter() first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000 elapsed = (first_message_time - start_time) * 1000
logger.info( logger.debug(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}", f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={ extra={
"json_fields": { "json_fields": {
@@ -568,7 +568,7 @@ async def _stream_listener(
# Stop listening on finish # Stop listening on finish
if isinstance(chunk, StreamFinish): if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000 total_time = (time.perf_counter() - start_time) * 1000
logger.info( logger.debug(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}", f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={ extra={
"json_fields": { "json_fields": {
@@ -587,7 +587,7 @@ async def _stream_listener(
except asyncio.CancelledError: except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000 elapsed = (time.perf_counter() - start_time) * 1000
logger.info( logger.debug(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}", f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={ extra={
"json_fields": { "json_fields": {
@@ -619,7 +619,7 @@ async def _stream_listener(
finally: finally:
# Clean up listener task mapping on exit # Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000 total_time = (time.perf_counter() - start_time) * 1000
logger.info( logger.debug(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, " f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}", f"delivered={messages_delivered}, xread_count={xread_count}",
extra={ extra={
@@ -829,10 +829,13 @@ async def get_active_task_for_session(
) )
await mark_task_completed(task_id, "failed") await mark_task_completed(task_id, "failed")
continue continue
except (ValueError, TypeError): except (ValueError, TypeError) as exc:
pass logger.warning(
f"[TASK_LOOKUP] Failed to parse created_at "
f"for task {task_id[:8]}...: {exc}"
)
logger.info( logger.debug(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..." f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
) )
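The hunks above only lower the [TIMING] log level, but their context lines show the registry's replay-then-follow pattern: one xread from the subscriber's last_message_id to replay history, then a blocking xread loop that delivers live chunks until a StreamFinish marker. A rough standalone sketch of that pattern with redis.asyncio; the stream key and the finish-marker field below are illustrative assumptions, not the registry's actual schema:

import asyncio

import redis.asyncio as aioredis


async def follow_stream(stream_key: str, last_id: str = "0") -> None:
    r = aioredis.Redis()
    # Step 1: replay whatever the stream already holds after last_id.
    for _key, entries in await r.xread({stream_key: last_id}, count=1000) or []:
        for entry_id, fields in entries:
            last_id = entry_id
            print("replayed:", fields)
    # Step 2: block for live entries until a finish marker shows up.
    while True:
        batch = await r.xread({stream_key: last_id}, block=30_000)
        for _key, entries in batch or []:
            for entry_id, fields in entries:
                last_id = entry_id
                if fields.get(b"type") == b"StreamFinish":  # assumed field name
                    return
                print("live:", fields)


asyncio.run(follow_stream("copilot:stream:demo"))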

View File

@@ -312,8 +312,18 @@ class ReadWorkspaceFileTool(BaseTool):
        is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
        is_text_file = self._is_text_mime_type(file_info.mime_type)
-       # Return inline content for small text files (unless force_download_url)
-       if is_small_file and is_text_file and not force_download_url:
+       # Return inline content for small text/image files (unless force_download_url)
+       is_image_file = file_info.mime_type in {
+           "image/png",
+           "image/jpeg",
+           "image/gif",
+           "image/webp",
+       }
+       if (
+           is_small_file
+           and (is_text_file or is_image_file)
+           and not force_download_url
+       ):
            content = await manager.read_file_by_id(target_file_id)
            content_b64 = base64.b64encode(content).decode("utf-8")
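A minimal standalone sketch of the gate this hunk adds; the size constant and helper below are illustrative stand-ins rather than the tool's real attributes, and the threshold value is assumed:

import base64

MAX_INLINE_SIZE_BYTES = 1 * 1024 * 1024  # assumed threshold, for illustration only
INLINE_IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/gif", "image/webp"}


def should_inline(size_bytes: int, mime_type: str, is_text: bool, force_download_url: bool) -> bool:
    """True when the file is small enough and is either text or a recognized image type."""
    if force_download_url:
        return False
    is_image = mime_type in INLINE_IMAGE_MIME_TYPES
    return size_bytes <= MAX_INLINE_SIZE_BYTES and (is_text or is_image)


# Small images now come back inline as base64 rather than only as a download URL.
payload = b"\x89PNG\r\n\x1a\n..."  # placeholder bytes, not a real PNG
if should_inline(len(payload), "image/png", is_text=False, force_download_url=False):
    content_b64 = base64.b64encode(payload).decode("utf-8")
    print(content_b64[:32])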

View File

@@ -599,6 +599,15 @@ def get_service_client(
            if error_response and error_response.type in EXCEPTION_MAPPING:
                exception_class = EXCEPTION_MAPPING[error_response.type]
                args = error_response.args or [str(e)]
+               # Prisma DataError subclasses expect a dict `data` arg,
+               # but RPC serialization only preserves the string message
+               # from exc.args. Wrap it in the expected structure so
+               # the constructor doesn't crash on `.get()`.
+               if issubclass(exception_class, DataError):
+                   msg = str(args[0]) if args else str(e)
+                   raise exception_class({"user_facing_error": {"message": msg}})
                raise exception_class(*args)
            # Otherwise categorize by HTTP status code
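A small sketch of why the wrapping matters, mirroring the new test below; it assumes, as the comment in this hunk states, that Prisma's DataError constructor reads a nested user_facing_error mapping rather than a plain string:

from prisma.errors import UniqueViolationError

msg = "Unique constraint failed on the fields: (`path`)"

# Reconstructing with the bare string would fail inside DataError.__init__,
# which calls .get() on its argument; the wrapped dict keeps it intact.
exc = UniqueViolationError({"user_facing_error": {"message": msg}})

assert isinstance(exc.data, dict)
assert "Unique constraint" in str(exc)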

View File

@@ -6,6 +6,7 @@ from unittest.mock import Mock
import httpx
import pytest
+from prisma.errors import DataError, UniqueViolationError

from backend.util.service import (
    AppService,
@@ -447,6 +448,39 @@ class TestHTTPErrorRetryBehavior:
        assert "Invalid parameter value" in str(exc_info.value)

+   def test_prisma_data_error_reconstructed_correctly(self):
+       """Test that DataError subclasses (e.g. UniqueViolationError) are
+       reconstructed without crashing.
+
+       Prisma's DataError.__init__ expects a dict `data` arg with
+       a 'user_facing_error' key. RPC serialization only preserves the
+       string message via exc.args, so the client must wrap it in the
+       expected dict structure.
+       """
+       for exc_type in [DataError, UniqueViolationError]:
+           mock_response = Mock()
+           mock_response.status_code = 400
+           mock_response.json.return_value = {
+               "type": exc_type.__name__,
+               "args": ["Unique constraint failed on the fields: (`path`)"],
+           }
+           mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+               "400 Bad Request", request=Mock(), response=mock_response
+           )
+
+           client = get_service_client(ServiceTestClient)
+           with pytest.raises(exc_type) as exc_info:
+               client._handle_call_method_response(  # type: ignore[attr-defined]
+                   response=mock_response, method_name="test_method"
+               )
+
+           # The exception should have the message preserved
+           assert "Unique constraint" in str(exc_info.value)
+           # And should have the expected data structure (not crash)
+           assert hasattr(exc_info.value, "data")
+           assert isinstance(exc_info.value.data, dict)
+
    def test_client_error_status_codes_coverage(self):
        """Test that various 4xx status codes are all wrapped as HTTPClientError."""
        client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]

View File

@@ -23,6 +23,7 @@ export function CopilotPage() {
    status,
    error,
    stop,
+   isReconnecting,
    createSession,
    onSend,
    isLoadingSession,
@@ -71,6 +72,7 @@ export function CopilotPage() {
          sessionId={sessionId}
          isLoadingSession={isLoadingSession}
          isCreatingSession={isCreatingSession}
+         isReconnecting={isReconnecting}
          onCreateSession={createSession}
          onSend={onSend}
          onStop={stop}

View File

@@ -14,6 +14,8 @@ export interface ChatContainerProps {
  sessionId: string | null;
  isLoadingSession: boolean;
  isCreatingSession: boolean;
+ /** True when backend has an active stream but we haven't reconnected yet. */
+ isReconnecting?: boolean;
  onCreateSession: () => void | Promise<string>;
  onSend: (message: string) => void | Promise<void>;
  onStop: () => void;
@@ -26,11 +28,13 @@ export const ChatContainer = ({
  sessionId,
  isLoadingSession,
  isCreatingSession,
+ isReconnecting,
  onCreateSession,
  onSend,
  onStop,
  headerSlot,
}: ChatContainerProps) => {
+ const isBusy = status === "streaming" || !!isReconnecting;
  const inputLayoutId = "copilot-2-chat-input";

  return (
@@ -56,8 +60,8 @@ export const ChatContainer = ({
          <ChatInput
            inputId="chat-input-session"
            onSend={onSend}
-           disabled={status === "streaming"}
-           isStreaming={status === "streaming"}
+           disabled={isBusy}
+           isStreaming={isBusy}
            onStop={onStop}
            placeholder="What else can I help with?"
          />

View File

@@ -1,63 +1,713 @@
"use client"; "use client";
import React from "react";
import { ToolUIPart } from "ai"; import { ToolUIPart } from "ai";
import { GearIcon } from "@phosphor-icons/react"; import {
CheckCircleIcon,
CircleDashedIcon,
CircleIcon,
FileIcon,
FilesIcon,
GearIcon,
GlobeIcon,
ListChecksIcon,
MagnifyingGlassIcon,
PencilSimpleIcon,
TerminalIcon,
TrashIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import {
ContentCodeBlock,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
interface Props { interface Props {
part: ToolUIPart; part: ToolUIPart;
} }
/* ------------------------------------------------------------------ */
/* Tool name helpers */
/* ------------------------------------------------------------------ */
function extractToolName(part: ToolUIPart): string { function extractToolName(part: ToolUIPart): string {
// ToolUIPart.type is "tool-{name}", extract the name portion.
return part.type.replace(/^tool-/, ""); return part.type.replace(/^tool-/, "");
} }
function formatToolName(name: string): string { function formatToolName(name: string): string {
// "search_docs" → "Search docs", "Read" → "Read"
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase()); return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
} }
function getAnimationText(part: ToolUIPart): string { /* ------------------------------------------------------------------ */
const label = formatToolName(extractToolName(part)); /* Tool categorization */
/* ------------------------------------------------------------------ */
type ToolCategory =
| "bash"
| "web"
| "file-read"
| "file-write"
| "file-delete"
| "file-list"
| "search"
| "edit"
| "todo"
| "other";
function getToolCategory(toolName: string): ToolCategory {
switch (toolName) {
case "bash_exec":
return "bash";
case "web_fetch":
case "WebSearch":
case "WebFetch":
return "web";
case "read_workspace_file":
case "Read":
return "file-read";
case "write_workspace_file":
case "Write":
return "file-write";
case "delete_workspace_file":
return "file-delete";
case "list_workspace_files":
case "Glob":
return "file-list";
case "Grep":
return "search";
case "Edit":
return "edit";
case "TodoWrite":
return "todo";
default:
return "other";
}
}
/* ------------------------------------------------------------------ */
/* Tool icon */
/* ------------------------------------------------------------------ */
function ToolIcon({
category,
isStreaming,
isError,
}: {
category: ToolCategory;
isStreaming: boolean;
isError: boolean;
}) {
if (isError) {
return (
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
);
}
if (isStreaming) {
return <OrbitLoader size={14} />;
}
const iconClass = "text-neutral-400";
switch (category) {
case "bash":
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
case "web":
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
case "file-read":
return <FileIcon size={14} weight="regular" className={iconClass} />;
case "file-write":
return <FileIcon size={14} weight="regular" className={iconClass} />;
case "file-delete":
return <TrashIcon size={14} weight="regular" className={iconClass} />;
case "file-list":
return <FilesIcon size={14} weight="regular" className={iconClass} />;
case "search":
return (
<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
);
case "edit":
return (
<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
);
case "todo":
return (
<ListChecksIcon size={14} weight="regular" className={iconClass} />
);
default:
return <GearIcon size={14} weight="regular" className={iconClass} />;
}
}
/* ------------------------------------------------------------------ */
/* Accordion icon (larger, for the accordion header) */
/* ------------------------------------------------------------------ */
function AccordionIcon({ category }: { category: ToolCategory }) {
switch (category) {
case "bash":
return <TerminalIcon size={32} weight="light" />;
case "web":
return <GlobeIcon size={32} weight="light" />;
case "file-read":
case "file-write":
return <FileIcon size={32} weight="light" />;
case "file-delete":
return <TrashIcon size={32} weight="light" />;
case "file-list":
return <FilesIcon size={32} weight="light" />;
case "search":
return <MagnifyingGlassIcon size={32} weight="light" />;
case "edit":
return <PencilSimpleIcon size={32} weight="light" />;
case "todo":
return <ListChecksIcon size={32} weight="light" />;
default:
return <GearIcon size={32} weight="light" />;
}
}
/* ------------------------------------------------------------------ */
/* Input extraction */
/* ------------------------------------------------------------------ */
function getInputSummary(toolName: string, input: unknown): string | null {
if (!input || typeof input !== "object") return null;
const inp = input as Record<string, unknown>;
switch (toolName) {
case "bash_exec":
return typeof inp.command === "string" ? inp.command : null;
case "web_fetch":
case "WebFetch":
return typeof inp.url === "string" ? inp.url : null;
case "WebSearch":
return typeof inp.query === "string" ? inp.query : null;
case "read_workspace_file":
case "Read":
return (
(typeof inp.file_path === "string" ? inp.file_path : null) ??
(typeof inp.path === "string" ? inp.path : null)
);
case "write_workspace_file":
case "Write":
return (
(typeof inp.file_path === "string" ? inp.file_path : null) ??
(typeof inp.path === "string" ? inp.path : null)
);
case "delete_workspace_file":
return typeof inp.file_path === "string" ? inp.file_path : null;
case "Glob":
return typeof inp.pattern === "string" ? inp.pattern : null;
case "Grep":
return typeof inp.pattern === "string" ? inp.pattern : null;
case "Edit":
return typeof inp.file_path === "string" ? inp.file_path : null;
case "TodoWrite": {
// Extract the in-progress task name for the status line
const todos = Array.isArray(inp.todos) ? inp.todos : [];
const active = todos.find(
(t: Record<string, unknown>) => t.status === "in_progress",
);
if (active && typeof active.activeForm === "string")
return active.activeForm;
if (active && typeof active.content === "string") return active.content;
return null;
}
default:
return null;
}
}
function truncate(text: string, maxLen: number): string {
if (text.length <= maxLen) return text;
return text.slice(0, maxLen).trimEnd() + "…";
}
/* ------------------------------------------------------------------ */
/* Animation text */
/* ------------------------------------------------------------------ */
function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
const toolName = extractToolName(part);
const summary = getInputSummary(toolName, part.input);
const shortSummary = summary ? truncate(summary, 60) : null;
  switch (part.state) {
    case "input-streaming":
-   case "input-available":
-     return `Running ${label}`;
-   case "output-available":
-     return `${label} completed`;
-   case "output-error":
-     return `${label} failed`;
-   default:
-     return `Running ${label}`;
+   case "input-available": {
+     switch (category) {
+       case "bash":
+         return shortSummary ? `Running: ${shortSummary}` : "Running command…";
+       case "web":
+         if (toolName === "WebSearch") {
return shortSummary
? `Searching "${shortSummary}"`
: "Searching the web…";
}
return shortSummary
? `Fetching ${shortSummary}`
: "Fetching web content…";
case "file-read":
return shortSummary ? `Reading ${shortSummary}` : "Reading file…";
case "file-write":
return shortSummary ? `Writing ${shortSummary}` : "Writing file…";
case "file-delete":
return shortSummary ? `Deleting ${shortSummary}` : "Deleting file…";
case "file-list":
return shortSummary ? `Listing ${shortSummary}` : "Listing files…";
case "search":
return shortSummary
? `Searching for "${shortSummary}"`
: "Searching…";
case "edit":
return shortSummary ? `Editing ${shortSummary}` : "Editing file…";
case "todo":
return shortSummary ? `${shortSummary}` : "Updating task list…";
+       default:
+         return `Running ${formatToolName(toolName)}`;
}
}
case "output-available": {
switch (category) {
case "bash": {
const exitCode = getExitCode(part.output);
if (exitCode !== null && exitCode !== 0) {
return `Command exited with code ${exitCode}`;
}
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
}
case "web":
if (toolName === "WebSearch") {
return shortSummary
? `Searched "${shortSummary}"`
: "Web search completed";
}
return shortSummary
? `Fetched ${shortSummary}`
: "Fetched web content";
case "file-read":
return shortSummary ? `Read ${shortSummary}` : "File read completed";
case "file-write":
return shortSummary ? `Wrote ${shortSummary}` : "File written";
case "file-delete":
return shortSummary ? `Deleted ${shortSummary}` : "File deleted";
case "file-list":
return "Listed files";
case "search":
return shortSummary
? `Searched for "${shortSummary}"`
: "Search completed";
case "edit":
return shortSummary ? `Edited ${shortSummary}` : "Edit completed";
case "todo":
return "Updated task list";
default:
return `${formatToolName(toolName)} completed`;
}
}
case "output-error": {
switch (category) {
case "bash":
return "Command failed";
case "web":
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
default:
return `${formatToolName(toolName)} failed`;
}
}
default:
return `Running ${formatToolName(toolName)}`;
  }
}
/* ------------------------------------------------------------------ */
/* Output parsing helpers */
/* ------------------------------------------------------------------ */
function parseOutput(output: unknown): Record<string, unknown> | null {
if (!output) return null;
if (typeof output === "object") return output as Record<string, unknown>;
if (typeof output === "string") {
const trimmed = output.trim();
if (!trimmed) return null;
try {
const parsed = JSON.parse(trimmed);
if (
typeof parsed === "object" &&
parsed !== null &&
!Array.isArray(parsed)
)
return parsed;
} catch {
// Return as a message wrapper for plain text output
return { _raw: trimmed };
}
}
return null;
}
/**
* Extract text from MCP-style content blocks.
* SDK built-in tools (WebSearch, etc.) may return `{content: [{type:"text", text:"..."}]}`.
*/
function extractMcpText(output: Record<string, unknown>): string | null {
if (Array.isArray(output.content)) {
const texts = (output.content as Array<Record<string, unknown>>)
.filter((b) => b.type === "text" && typeof b.text === "string")
.map((b) => b.text as string);
if (texts.length > 0) return texts.join("\n");
}
return null;
}
function getExitCode(output: unknown): number | null {
const parsed = parseOutput(output);
if (!parsed) return null;
if (typeof parsed.exit_code === "number") return parsed.exit_code;
return null;
}
function getStringField(
obj: Record<string, unknown>,
...keys: string[]
): string | null {
for (const key of keys) {
if (typeof obj[key] === "string" && obj[key].length > 0)
return obj[key] as string;
}
return null;
}
/* ------------------------------------------------------------------ */
/* Accordion content per tool category */
/* ------------------------------------------------------------------ */
interface AccordionData {
title: string;
description?: string;
content: React.ReactNode;
}
function getBashAccordionData(
input: unknown,
output: Record<string, unknown>,
): AccordionData {
const inp = (input && typeof input === "object" ? input : {}) as Record<
string,
unknown
>;
const command = typeof inp.command === "string" ? inp.command : "Command";
const stdout = getStringField(output, "stdout");
const stderr = getStringField(output, "stderr");
const exitCode =
typeof output.exit_code === "number" ? output.exit_code : null;
const timedOut = output.timed_out === true;
const message = getStringField(output, "message");
const title = timedOut
? "Command timed out"
: exitCode !== null && exitCode !== 0
? `Command failed (exit ${exitCode})`
: "Command output";
return {
title,
description: truncate(command, 80),
content: (
<div className="space-y-2">
{stdout && (
<div>
<p className="mb-1 text-xs font-medium text-slate-500">stdout</p>
<ContentCodeBlock>{truncate(stdout, 2000)}</ContentCodeBlock>
</div>
)}
{stderr && (
<div>
<p className="mb-1 text-xs font-medium text-slate-500">stderr</p>
<ContentCodeBlock>{truncate(stderr, 1000)}</ContentCodeBlock>
</div>
)}
{!stdout && !stderr && message && (
<ContentMessage>{message}</ContentMessage>
)}
</div>
),
};
}
function getWebAccordionData(
input: unknown,
output: Record<string, unknown>,
): AccordionData {
const inp = (input && typeof input === "object" ? input : {}) as Record<
string,
unknown
>;
const url =
getStringField(inp as Record<string, unknown>, "url", "query") ??
"Web content";
// Try direct string fields first, then MCP content blocks, then raw JSON
let content = getStringField(output, "content", "text", "_raw");
if (!content) content = extractMcpText(output);
if (!content) {
// Fallback: render the raw JSON so the accordion isn't empty
try {
const raw = JSON.stringify(output, null, 2);
if (raw !== "{}") content = raw;
} catch {
/* empty */
}
}
const statusCode =
typeof output.status_code === "number" ? output.status_code : null;
const message = getStringField(output, "message");
return {
title: statusCode
? `Response (${statusCode})`
: url
? "Web fetch"
: "Search results",
description: truncate(url, 80),
content: content ? (
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
) : message ? (
<ContentMessage>{message}</ContentMessage>
) : Object.keys(output).length > 0 ? (
<ContentCodeBlock>
{truncate(JSON.stringify(output, null, 2), 2000)}
</ContentCodeBlock>
) : null,
};
}
function getFileAccordionData(
input: unknown,
output: Record<string, unknown>,
): AccordionData {
const inp = (input && typeof input === "object" ? input : {}) as Record<
string,
unknown
>;
const filePath =
getStringField(
inp as Record<string, unknown>,
"file_path",
"path",
"pattern",
) ?? "File";
const content = getStringField(output, "content", "text", "_raw");
const message = getStringField(output, "message");
// For Glob/list results, try to show file list
const files = Array.isArray(output.files)
? output.files.filter((f: unknown): f is string => typeof f === "string")
: null;
return {
title: message ?? "File output",
description: truncate(filePath, 80),
content: (
<div className="space-y-2">
{content && (
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
)}
{files && files.length > 0 && (
<ContentCodeBlock>
{truncate(files.join("\n"), 2000)}
</ContentCodeBlock>
)}
{!content && !files && message && (
<ContentMessage>{message}</ContentMessage>
)}
</div>
),
};
}
interface TodoItem {
content: string;
status: "pending" | "in_progress" | "completed";
activeForm?: string;
}
function getTodoAccordionData(input: unknown): AccordionData {
const inp = (input && typeof input === "object" ? input : {}) as Record<
string,
unknown
>;
const todos: TodoItem[] = Array.isArray(inp.todos)
? inp.todos.filter(
(t: unknown): t is TodoItem =>
typeof t === "object" &&
t !== null &&
typeof (t as TodoItem).content === "string",
)
: [];
const completed = todos.filter((t) => t.status === "completed").length;
const total = todos.length;
return {
title: "Task list",
description: `${completed}/${total} completed`,
content: (
<div className="space-y-1 py-1">
{todos.map((todo, i) => (
<div key={i} className="flex items-start gap-2 text-xs">
<span className="mt-0.5 flex-shrink-0">
{todo.status === "completed" ? (
<CheckCircleIcon
size={14}
weight="fill"
className="text-green-500"
/>
) : todo.status === "in_progress" ? (
<CircleDashedIcon
size={14}
weight="bold"
className="text-blue-500"
/>
) : (
<CircleIcon
size={14}
weight="regular"
className="text-neutral-400"
/>
)}
</span>
<span
className={
todo.status === "completed"
? "text-muted-foreground line-through"
: todo.status === "in_progress"
? "font-medium text-foreground"
: "text-muted-foreground"
}
>
{todo.content}
</span>
</div>
))}
</div>
),
};
}
function getDefaultAccordionData(
output: Record<string, unknown>,
): AccordionData {
const message = getStringField(output, "message");
const raw = output._raw;
const mcpText = extractMcpText(output);
let displayContent: string;
if (typeof raw === "string") {
displayContent = raw;
} else if (mcpText) {
displayContent = mcpText;
} else if (message) {
displayContent = message;
} else {
try {
displayContent = JSON.stringify(output, null, 2);
} catch {
displayContent = String(output);
}
}
return {
title: "Output",
description: message ?? undefined,
content: (
<ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
),
};
}
function getAccordionData(
category: ToolCategory,
input: unknown,
output: Record<string, unknown>,
): AccordionData {
switch (category) {
case "bash":
return getBashAccordionData(input, output);
case "web":
return getWebAccordionData(input, output);
case "file-read":
case "file-write":
case "file-delete":
case "file-list":
case "search":
case "edit":
return getFileAccordionData(input, output);
case "todo":
return getTodoAccordionData(input);
default:
return getDefaultAccordionData(output);
}
}
/* ------------------------------------------------------------------ */
/* Component */
/* ------------------------------------------------------------------ */
export function GenericTool({ part }: Props) {
+  const toolName = extractToolName(part);
+  const category = getToolCategory(toolName);
  const isStreaming =
    part.state === "input-streaming" || part.state === "input-available";
  const isError = part.state === "output-error";
+  const text = getAnimationText(part, category);
+  const output = parseOutput(part.output);
+  const hasOutput =
+    part.state === "output-available" &&
+    !!output &&
+    Object.keys(output).length > 0;
+  const hasError = isError && !!output;
+  // TodoWrite: always show accordion from input (the todo list lives in input)
+  const hasTodoInput =
+    category === "todo" &&
+    part.input &&
+    typeof part.input === "object" &&
+    Array.isArray((part.input as Record<string, unknown>).todos);
+  const showAccordion = hasOutput || hasError || hasTodoInput;
+  const accordionData = showAccordion
+    ? getAccordionData(category, part.input, output ?? {})
+    : null;

  return (
    <div className="py-2">
      <div className="flex items-center gap-2 text-sm text-muted-foreground">
-        <GearIcon
-          size={14}
-          weight="regular"
-          className={
-            isError
-              ? "text-red-500"
-              : isStreaming
-                ? "animate-spin text-neutral-500"
-                : "text-neutral-400"
-          }
+        <ToolIcon
+          category={category}
+          isStreaming={isStreaming}
+          isError={isError}
        />
        <MorphingTextAnimation
-          text={getAnimationText(part)}
+          text={text}
          className={isError ? "text-red-500" : undefined}
        />
      </div>
+      {showAccordion && accordionData ? (
+        <ToolAccordion
+          icon={<AccordionIcon category={category} />}
+          title={accordionData.title}
+          description={accordionData.description}
+          titleClassName={isError ? "text-red-500" : undefined}
+        >
+          {accordionData.content}
+        </ToolAccordion>
+      ) : null}
    </div>
  );
}

View File

@@ -50,6 +50,14 @@ export function useChatSession() {
    );
  }, [sessionQuery.data, sessionId]);

+ // Expose active_stream info so the caller can trigger manual resume
+ // after hydration completes (rather than relying on AI SDK's built-in
+ // resume which fires before hydration).
+ const hasActiveStream = useMemo(() => {
+   if (sessionQuery.data?.status !== 200) return false;
+   return !!sessionQuery.data.data.active_stream;
+ }, [sessionQuery.data]);
+
  const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
    usePostV2CreateSession({
      mutation: {
@@ -102,6 +110,7 @@ export function useChatSession() {
    sessionId,
    setSessionId,
    hydratedMessages,
+   hasActiveStream,
    isLoadingSession: sessionQuery.isLoading,
    createSession,
    isCreatingSession,

View File

@@ -29,6 +29,7 @@ export function useCopilotPage() {
    sessionId,
    setSessionId,
    hydratedMessages,
+   hasActiveStream,
    isLoadingSession,
    createSession,
    isCreatingSession,
@@ -80,14 +81,31 @@ export function useCopilotPage() {
            },
          };
        },
+       // Resume: GET goes to the same URL as POST (backend uses
+       // method to distinguish). Override the default formula which
+       // would append /{chatId}/stream to the existing path.
+       prepareReconnectToStreamRequest: () => ({
+         api: `/api/chat/sessions/${sessionId}/stream`,
+       }),
      })
    : null,
    [sessionId],
  );

- const { messages, sendMessage, stop, status, error, setMessages } = useChat({
+ const {
+   messages,
+   sendMessage,
+   stop,
+   status,
+   error,
+   setMessages,
+   resumeStream,
+ } = useChat({
    id: sessionId ?? undefined,
    transport: transport ?? undefined,
+   // Don't use resume: true — it fires before hydration completes, causing
+   // the hydrated messages to overwrite the resumed stream. Instead we
+   // call resumeStream() manually after hydration + active_stream detection.
  });

  // Abort the stream if the backend doesn't start sending data within 12s.
@@ -108,13 +126,31 @@ export function useCopilotPage() {
    return () => clearTimeout(timer);
  }, [status]);

+ // Hydrate messages from the REST session endpoint.
+ // Skip hydration while streaming to avoid overwriting the live stream.
  useEffect(() => {
    if (!hydratedMessages || hydratedMessages.length === 0) return;
+   if (status === "streaming" || status === "submitted") return;
    setMessages((prev) => {
      if (prev.length >= hydratedMessages.length) return prev;
      return hydratedMessages;
    });
- }, [hydratedMessages, setMessages]);
+ }, [hydratedMessages, setMessages, status]);
+
+ // Resume an active stream AFTER hydration completes.
+ // The backend returns active_stream info when a task is still running.
+ // We wait for hydration so the AI SDK has the conversation history
+ // before the resumed stream appends the in-progress assistant message.
+ const hasResumedRef = useRef<string | null>(null);
+ useEffect(() => {
+   if (!hasActiveStream || !sessionId) return;
+   if (!hydratedMessages || hydratedMessages.length === 0) return;
+   if (status === "streaming" || status === "submitted") return;
+   // Only resume once per session to avoid re-triggering after stream ends
+   if (hasResumedRef.current === sessionId) return;
+   hasResumedRef.current = sessionId;
+   resumeStream();
+ }, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);

  // Poll session endpoint when a long-running tool (create_agent, edit_agent)
  // is in progress. When the backend completes, the session data will contain
@@ -197,12 +233,18 @@ export function useCopilotPage() {
    }
  }, [isDeleting]);

+ // True while we know the backend has an active stream but haven't
+ // reconnected yet. Used to disable the send button and show stop UI.
+ const isReconnecting =
+   hasActiveStream && status !== "streaming" && status !== "submitted";
+
  return {
    sessionId,
    messages,
    status,
    error,
    stop,
+   isReconnecting,
    isLoadingSession,
    isCreatingSession,
    isUserLoading,