Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-19 02:54:28 -05:00)

Compare commits: dev ... copilot/sd (10 commits)

| SHA1 |
|---|
| ecfe4e6a7a |
| efb4b3b518 |
| ebeab7fbe6 |
| 98ef8a26ab |
| ed02e6db9e |
| 6952334b85 |
| 0c586c2edf |
| b6128dd75f |
| c4f5f7c8b8 |
| 8af4e0bf7d |
@@ -4,7 +4,6 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
 RabbitMQ and processes them using a thread pool, following the graph executor pattern.
 """
 
-import asyncio
 import logging
 import os
 import threading
@@ -53,6 +53,7 @@ class SDKResponseAdapter:
         self.has_started_text = False
         self.has_ended_text = False
         self.current_tool_calls: dict[str, dict[str, str]] = {}
+        self.resolved_tool_calls: set[str] = set()
         self.task_id: str | None = None
         self.step_open = False
 
@@ -74,6 +75,10 @@ class SDKResponseAdapter:
             self.step_open = True
 
         elif isinstance(sdk_message, AssistantMessage):
+            # Flush any SDK built-in tool calls that didn't get a UserMessage
+            # result (e.g. WebSearch, Read handled internally by the CLI).
+            self._flush_unresolved_tool_calls(responses)
+
             # After tool results, the SDK sends a new AssistantMessage for the
             # next LLM turn. Open a new step if the previous one was closed.
             if not self.step_open:
@@ -111,6 +116,8 @@ class SDKResponseAdapter:
             # UserMessage carries tool results back from tool execution.
             content = sdk_message.content
             blocks = content if isinstance(content, list) else []
+            resolved_in_blocks: set[str] = set()
+
             for block in blocks:
                 if isinstance(block, ToolResultBlock) and block.tool_use_id:
                     tool_info = self.current_tool_calls.get(block.tool_use_id, {})
@@ -132,6 +139,37 @@ class SDKResponseAdapter:
                             success=not (block.is_error or False),
                         )
                     )
+                    resolved_in_blocks.add(block.tool_use_id)
+
+            # Handle SDK built-in tool results carried via parent_tool_use_id
+            # instead of (or in addition to) ToolResultBlock content.
+            parent_id = sdk_message.parent_tool_use_id
+            if parent_id and parent_id not in resolved_in_blocks:
+                tool_info = self.current_tool_calls.get(parent_id, {})
+                tool_name = tool_info.get("name", "unknown")
+
+                # Try stashed output first (from PostToolUse hook),
+                # then tool_use_result dict, then string content.
+                output = pop_pending_tool_output(tool_name)
+                if not output:
+                    tur = sdk_message.tool_use_result
+                    if tur is not None:
+                        output = _extract_tool_use_result(tur)
+                if not output and isinstance(content, str) and content.strip():
+                    output = content.strip()
+
+                if output:
+                    responses.append(
+                        StreamToolOutputAvailable(
+                            toolCallId=parent_id,
+                            toolName=tool_name,
+                            output=output,
+                            success=True,
+                        )
+                    )
+                    resolved_in_blocks.add(parent_id)
+
+            self.resolved_tool_calls.update(resolved_in_blocks)
 
             # Close the current step after tool results — the next
             # AssistantMessage will open a new step for the continuation.
@@ -140,6 +178,7 @@ class SDKResponseAdapter:
                 self.step_open = False
 
         elif isinstance(sdk_message, ResultMessage):
+            self._flush_unresolved_tool_calls(responses)
            self._end_text_if_open(responses)
            # Close the step before finishing.
            if self.step_open:
@@ -149,7 +188,7 @@ class SDKResponseAdapter:
             if sdk_message.subtype == "success":
                 responses.append(StreamFinish())
             elif sdk_message.subtype in ("error", "error_during_execution"):
-                error_msg = getattr(sdk_message, "result", None) or "Unknown error"
+                error_msg = sdk_message.result or "Unknown error"
                 responses.append(
                     StreamError(errorText=str(error_msg), code="sdk_error")
                 )
@@ -180,6 +219,59 @@ class SDKResponseAdapter:
             responses.append(StreamTextEnd(id=self.text_block_id))
             self.has_ended_text = True
 
+    def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
+        """Emit outputs for tool calls that didn't receive a UserMessage result.
+
+        SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
+        internally without surfacing a separate ``UserMessage`` with
+        ``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
+        output, which we pop and emit here before the next ``AssistantMessage``
+        starts.
+        """
+        flushed = False
+        for tool_id, tool_info in self.current_tool_calls.items():
+            if tool_id in self.resolved_tool_calls:
+                continue
+            tool_name = tool_info.get("name", "unknown")
+            output = pop_pending_tool_output(tool_name)
+            if output is not None:
+                responses.append(
+                    StreamToolOutputAvailable(
+                        toolCallId=tool_id,
+                        toolName=tool_name,
+                        output=output,
+                        success=True,
+                    )
+                )
+                self.resolved_tool_calls.add(tool_id)
+                flushed = True
+                logger.debug(
+                    f"Flushed pending output for built-in tool {tool_name} "
+                    f"(call {tool_id})"
+                )
+            else:
+                # No output available — emit an empty output so the frontend
+                # transitions the tool from input-available to output-available
+                # (stops the spinner).
+                responses.append(
+                    StreamToolOutputAvailable(
+                        toolCallId=tool_id,
+                        toolName=tool_name,
+                        output="",
+                        success=True,
+                    )
+                )
+                self.resolved_tool_calls.add(tool_id)
+                flushed = True
+                logger.debug(
+                    f"Flushed empty output for unresolved tool {tool_name} "
+                    f"(call {tool_id})"
+                )
+
+        if flushed and self.step_open:
+            responses.append(StreamFinishStep())
+            self.step_open = False
+
+
 def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
     """Extract a string output from a ToolResultBlock's content field."""
@@ -199,3 +291,30 @@ def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
         return json.dumps(content)
     except (TypeError, ValueError):
         return str(content)
+
+
+def _extract_tool_use_result(result: object) -> str:
+    """Extract a string from a UserMessage's ``tool_use_result`` dict.
+
+    SDK built-in tools may store their result in ``tool_use_result``
+    instead of (or in addition to) ``ToolResultBlock`` content blocks.
+    """
+    if isinstance(result, str):
+        return result
+    if isinstance(result, dict):
+        # Try common result keys
+        for key in ("content", "text", "output", "stdout", "result"):
+            val = result.get(key)
+            if isinstance(val, str) and val:
+                return val
+        # Fall back to JSON serialization of the whole dict
+        try:
+            return json.dumps(result)
+        except (TypeError, ValueError):
+            return str(result)
+    if result is None:
+        return ""
+    try:
+        return json.dumps(result)
+    except (TypeError, ValueError):
+        return str(result)
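The fallback order in `_extract_tool_use_result` above (plain string, then well-known dict keys, then `json.dumps`, then `str`) is easiest to see with a few concrete inputs. A small illustrative sketch, not part of the diff; the helper is re-declared locally so the calls run standalone:

```python
import json


def extract(result: object) -> str:
    # Condensed re-statement of _extract_tool_use_result, for illustration only.
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        for key in ("content", "text", "output", "stdout", "result"):
            val = result.get(key)
            if isinstance(val, str) and val:
                return val
        try:
            return json.dumps(result)
        except (TypeError, ValueError):
            return str(result)
    if result is None:
        return ""
    try:
        return json.dumps(result)
    except (TypeError, ValueError):
        return str(result)


print(extract("plain text"))                # -> plain text
print(extract({"content": "tool output"}))  # -> tool output
print(extract({"rows": [1, 2]}))            # -> {"rows": [1, 2]}
print(extract(None))                        # -> (empty string)
```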
@@ -16,6 +16,7 @@ from .tool_adapter import (
     DANGEROUS_PATTERNS,
     MCP_TOOL_PREFIX,
     WORKSPACE_SCOPED_TOOLS,
+    stash_pending_tool_output,
 )
 
 logger = logging.getLogger(__name__)
@@ -224,10 +225,25 @@ def create_security_hooks(
         tool_use_id: str | None,
         context: HookContext,
     ) -> SyncHookJSONOutput:
-        """Log successful tool executions for observability."""
+        """Log successful tool executions and stash SDK built-in tool outputs.
+
+        MCP tools stash their output in ``_execute_tool_sync`` before the
+        SDK can truncate it. SDK built-in tools (WebSearch, Read, etc.)
+        are executed by the CLI internally — this hook captures their
+        output so the response adapter can forward it to the frontend.
+        """
         _ = context
         tool_name = cast(str, input_data.get("tool_name", ""))
         logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
+
+        # Stash output for SDK built-in tools so the response adapter can
+        # emit StreamToolOutputAvailable even when the CLI doesn't surface
+        # a separate UserMessage with ToolResultBlock content.
+        if not tool_name.startswith(MCP_TOOL_PREFIX):
+            tool_response = input_data.get("tool_response")
+            if tool_response is not None:
+                stash_pending_tool_output(tool_name, tool_response)
+
         return cast(SyncHookJSONOutput, {})
 
     async def post_tool_failure_hook(
@@ -47,6 +47,7 @@ from .tool_adapter import (
     set_execution_context,
 )
 from .transcript import (
+    cleanup_cli_project_dir,
     download_transcript,
     read_transcript_file,
     upload_transcript,
@@ -86,9 +87,12 @@ _SDK_TOOL_SUPPLEMENT = """
   for shell commands — it runs in a network-isolated sandbox.
 - **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
   same working directory. Files created by one are readable by the other.
-  These files are **ephemeral** — they exist only for the current session.
-- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
-  for files that should persist across sessions (stored in cloud storage).
+- **IMPORTANT — File persistence**: Your working directory is **ephemeral** —
+  files are lost between turns. When you create or modify important files
+  (code, configs, outputs), you MUST save them using `write_workspace_file`
+  so they persist. Use `read_workspace_file` and `list_workspace_files` to
+  access files saved in previous turns. If a "Files from previous turns"
+  section is present above, those files are available via `read_workspace_file`.
 - Long-running tools (create_agent, edit_agent, etc.) are handled
   asynchronously. You will receive an immediate response; the actual result
   is delivered to the user via a background stream.
@@ -268,48 +272,28 @@ def _make_sdk_cwd(session_id: str) -> str:
 
 
 def _cleanup_sdk_tool_results(cwd: str) -> None:
-    """Remove SDK tool-result files for a specific session working directory.
+    """Remove SDK session artifacts for a specific working directory.
 
-    The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
-    We clean only the specific cwd's results to avoid race conditions between
-    concurrent sessions.
+    Cleans up:
+    - ``~/.claude/projects/<encoded-cwd>/`` — CLI session transcripts and
+      tool-result files. Each SDK turn uses a unique cwd, so this directory
+      is safe to remove entirely.
+    - ``/tmp/copilot-<session>/`` — the ephemeral working directory.
 
-    Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
+    Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
+    the session_id.
     """
     import shutil
 
-    # Validate cwd is under the expected prefix
     normalized = os.path.normpath(cwd)
     if not normalized.startswith(_SDK_CWD_PREFIX):
         logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
         return
 
-    # SDK encodes the cwd path by replacing '/' with '-'
-    encoded_cwd = normalized.replace("/", "-")
+    # Clean the CLI's project directory (transcripts + tool-results).
+    cleanup_cli_project_dir(cwd)
 
-    # Construct the project directory path (known-safe home expansion)
-    claude_projects = os.path.expanduser("~/.claude/projects")
-    project_dir = os.path.join(claude_projects, encoded_cwd)
-
-    # Security check 3: Validate project_dir is under ~/.claude/projects
-    project_dir = os.path.normpath(project_dir)
-    if not project_dir.startswith(claude_projects):
-        logger.warning(
-            f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
-        )
-        return
-
-    results_dir = os.path.join(project_dir, "tool-results")
-    if os.path.isdir(results_dir):
-        for filename in os.listdir(results_dir):
-            file_path = os.path.join(results_dir, filename)
-            try:
-                if os.path.isfile(file_path):
-                    os.remove(file_path)
-            except OSError:
-                pass
-
-    # Also clean up the temp cwd directory itself
+    # Clean up the temp cwd directory itself.
     try:
         shutil.rmtree(normalized, ignore_errors=True)
     except OSError:
@@ -519,6 +503,7 @@ async def stream_chat_completion_sdk(
     def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
         captured_transcript.path = transcript_path
         captured_transcript.sdk_session_id = sdk_session_id
+        logger.debug(f"[SDK] Stop hook: path={transcript_path!r}")
 
     security_hooks = create_security_hooks(
         user_id,
@@ -530,18 +515,20 @@ async def stream_chat_completion_sdk(
     # --- Resume strategy: download transcript from bucket ---
     resume_file: str | None = None
     use_resume = False
+    transcript_msg_count = 0  # watermark: session.messages length at upload
 
     if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
-        transcript_content = await download_transcript(user_id, session_id)
-        if transcript_content and validate_transcript(transcript_content):
+        dl = await download_transcript(user_id, session_id)
+        if dl and validate_transcript(dl.content):
             resume_file = write_transcript_to_tempfile(
-                transcript_content, session_id, sdk_cwd
+                dl.content, session_id, sdk_cwd
             )
             if resume_file:
                 use_resume = True
-                logger.info(
-                    f"[SDK] Using --resume with transcript "
-                    f"({len(transcript_content)} bytes)"
+                transcript_msg_count = dl.message_count
+                logger.debug(
+                    f"[SDK] Using --resume ({len(dl.content)}B, "
+                    f"msg_count={transcript_msg_count})"
                 )
 
     sdk_options_kwargs: dict[str, Any] = {
@@ -582,11 +569,35 @@ async def stream_chat_completion_sdk(
             # Build query: with --resume the CLI already has full
             # context, so we only send the new message. Without
             # resume, compress history into a context prefix.
+            #
+            # Hybrid mode: if the transcript is stale (upload missed
+            # some turns), compress only the gap and prepend it so
+            # the agent has transcript context + missed turns.
             query_message = current_message
-            if not use_resume and len(session.messages) > 1:
+            current_msg_count = len(session.messages)
+
+            if use_resume and transcript_msg_count >= 0:
+                # Transcript covers messages[0..M-1]. Current session
+                # has N messages (last one is the new user msg).
+                # Gap = messages[M .. N-2] (everything between upload
+                # and the current turn).
+                if transcript_msg_count < current_msg_count - 1:
+                    gap = session.messages[transcript_msg_count:-1]
+                    gap_context = _format_conversation_context(gap)
+                    if gap_context:
+                        logger.info(
+                            f"[SDK] Transcript stale: covers {transcript_msg_count} "
+                            f"of {current_msg_count} messages, compressing "
+                            f"{len(gap)} missed messages"
+                        )
+                        query_message = (
+                            f"{gap_context}\n\n"
+                            f"Now, the user says:\n{current_message}"
+                        )
+            elif not use_resume and current_msg_count > 1:
                 logger.warning(
                     f"[SDK] Using compression fallback for session "
-                    f"{session_id} ({len(session.messages)} messages) — "
+                    f"{session_id} ({current_msg_count} messages) — "
                     f"no transcript available for --resume"
                 )
                 compressed = await _compress_conversation_history(session)
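The hybrid-resume hunk above leans on a simple slice: the transcript covers `messages[0..M-1]`, the live session holds `N` messages whose last entry is the new user turn, so only `messages[M..N-2]` need to be compressed into the gap prefix. A self-contained sketch of that arithmetic (illustrative values, not the module's data):

```python
# Illustrative gap-slice arithmetic for the stale-transcript case.
messages = ["u1", "a1", "u2", "a2", "u3"]  # N = 5; "u3" is the new user message
transcript_msg_count = 2                   # M: the uploaded transcript covered ["u1", "a1"]

if transcript_msg_count < len(messages) - 1:
    gap = messages[transcript_msg_count:-1]  # everything between upload and this turn
    print(gap)                               # -> ['u2', 'a2']
```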
@@ -597,10 +608,10 @@ async def stream_chat_completion_sdk(
                     f"Now, the user says:\n{current_message}"
                 )
 
-            logger.info(
-                f"[SDK] Sending query ({len(session.messages)} msgs in session)"
+            logger.debug(
+                f"[SDK] Sending query ({len(session.messages)} msgs, "
+                f"resume={use_resume})"
             )
-            logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
             await client.query(query_message, session_id=session_id)
 
             assistant_response = ChatMessage(role="assistant", content="")
@@ -681,29 +692,33 @@ async def stream_chat_completion_sdk(
             ) and not has_appended_assistant:
                 session.messages.append(assistant_response)
 
-            # --- Capture transcript while CLI is still alive ---
-            # Must happen INSIDE async with: close() sends SIGTERM
-            # which kills the CLI before it can flush the JSONL.
-            if (
-                config.claude_agent_use_resume
-                and user_id
-                and captured_transcript.available
-            ):
-                # Give CLI time to flush JSONL writes before we read
-                await asyncio.sleep(0.5)
+            # --- Upload transcript for next-turn --resume ---
+            # After async with the SDK task group has exited, so the Stop
+            # hook has already fired and the CLI has been SIGTERMed. The
+            # CLI uses appendFileSync, so all writes are safely on disk.
+            if config.claude_agent_use_resume and user_id:
+                # With --resume the CLI appends to the resume file (most
+                # complete). Otherwise use the Stop hook path.
+                if use_resume and resume_file:
+                    raw_transcript = read_transcript_file(resume_file)
+                elif captured_transcript.path:
                     raw_transcript = read_transcript_file(captured_transcript.path)
-                if raw_transcript:
-                    try:
-                        async with asyncio.timeout(30):
-                            await _upload_transcript_bg(
-                                user_id, session_id, raw_transcript
-                            )
-                    except asyncio.TimeoutError:
-                        logger.warning(
-                            f"[SDK] Transcript upload timed out for {session_id}"
-                        )
-                else:
-                    logger.debug("[SDK] Stop hook fired but transcript not usable")
+                else:
+                    raw_transcript = None
+
+                if raw_transcript:
+                    # Shield the upload from generator cancellation so a
+                    # client disconnect / page refresh doesn't lose the
+                    # transcript. The upload must finish even if the SSE
+                    # connection is torn down.
+                    await asyncio.shield(
+                        _try_upload_transcript(
+                            user_id,
+                            session_id,
+                            raw_transcript,
+                            message_count=len(session.messages),
+                        )
+                    )
 
         except ImportError:
             raise RuntimeError(
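The `asyncio.shield(...)` wrappers introduced above (around the transcript upload here, and around `upsert_chat_session` in the next two hunks) keep the inner coroutine running even when the outer awaiter is cancelled, for example by an SSE disconnect. A minimal standalone sketch of that behaviour, unrelated to the project's code:

```python
import asyncio


async def upload() -> str:
    await asyncio.sleep(0.2)  # stands in for the transcript upload
    return "uploaded"


async def main() -> None:
    task = asyncio.ensure_future(upload())
    try:
        # Cancelling the outer wait (here via a timeout) does not cancel
        # the shielded task itself.
        await asyncio.wait_for(asyncio.shield(task), timeout=0.05)
    except asyncio.TimeoutError:
        print("caller gave up, but the upload keeps running")
    print(await task)  # -> uploaded


asyncio.run(main())
```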
@@ -712,7 +727,7 @@ async def stream_chat_completion_sdk(
                 "to use the OpenAI-compatible fallback."
             )
 
-        await upsert_chat_session(session)
+        await asyncio.shield(upsert_chat_session(session))
         logger.debug(
             f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
         )
@@ -722,7 +737,7 @@ async def stream_chat_completion_sdk(
     except Exception as e:
         logger.error(f"[SDK] Error: {e}", exc_info=True)
         try:
-            await upsert_chat_session(session)
+            await asyncio.shield(upsert_chat_session(session))
         except Exception as save_err:
             logger.error(f"[SDK] Failed to save session on error: {save_err}")
         yield StreamError(
@@ -735,14 +750,31 @@ async def stream_chat_completion_sdk(
         _cleanup_sdk_tool_results(sdk_cwd)
 
 
-async def _upload_transcript_bg(
-    user_id: str, session_id: str, raw_content: str
-) -> None:
-    """Background task to strip progress entries and upload transcript."""
+async def _try_upload_transcript(
+    user_id: str,
+    session_id: str,
+    raw_content: str,
+    message_count: int = 0,
+) -> bool:
+    """Strip progress entries and upload transcript (with timeout).
+
+    Returns True if the upload completed without error.
+    """
     try:
-        await upload_transcript(user_id, session_id, raw_content)
+        async with asyncio.timeout(30):
+            await upload_transcript(
+                user_id, session_id, raw_content, message_count=message_count
+            )
+        return True
+    except asyncio.TimeoutError:
+        logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
+        return False
     except Exception as e:
-        logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
+        logger.error(
+            f"[SDK] Failed to upload transcript for {session_id}: {e}",
+            exc_info=True,
+        )
+        return False
 
 
 async def _update_title_async(
@@ -41,7 +41,7 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
 # Stash for MCP tool outputs before the SDK potentially truncates them.
 # Keyed by tool_name → full output string. Consumed (popped) by the
 # response adapter when it builds StreamToolOutputAvailable.
-_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
+_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
     "pending_tool_outputs", default=None  # type: ignore[arg-type]
 )
 
@@ -88,19 +88,52 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:
 
 
 def pop_pending_tool_output(tool_name: str) -> str | None:
-    """Pop and return the stashed full output for *tool_name*.
+    """Pop and return the oldest stashed output for *tool_name*.
 
     The SDK CLI may truncate large tool results (writing them to disk and
     replacing the content with a file reference). This stash keeps the
     original MCP output so the response adapter can forward it to the
     frontend for proper widget rendering.
 
+    Uses a FIFO queue per tool name so duplicate calls to the same tool
+    in one turn each get their own output.
+
     Returns ``None`` if nothing was stashed for *tool_name*.
     """
     pending = _pending_tool_outputs.get(None)
     if pending is None:
         return None
-    return pending.pop(tool_name, None)
+    queue = pending.get(tool_name)
+    if not queue:
+        pending.pop(tool_name, None)
+        return None
+    value = queue.pop(0)
+    if not queue:
+        del pending[tool_name]
+    return value
 
 
+def stash_pending_tool_output(tool_name: str, output: Any) -> None:
+    """Stash tool output for later retrieval by the response adapter.
+
+    Used by the PostToolUse hook to capture SDK built-in tool outputs
+    (WebSearch, Read, etc.) that aren't available through the MCP stash
+    mechanism in ``_execute_tool_sync``.
+
+    Appends to a FIFO queue per tool name so multiple calls to the same
+    tool in one turn are all preserved.
+    """
+    pending = _pending_tool_outputs.get(None)
+    if pending is None:
+        return
+    if isinstance(output, str):
+        text = output
+    else:
+        try:
+            text = json.dumps(output)
+        except (TypeError, ValueError):
+            text = str(output)
+    pending.setdefault(tool_name, []).append(text)
+
+
 async def _execute_tool_sync(
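The stash is now a FIFO list per tool name, so two calls to the same built-in tool in one turn each get their own output back. A minimal sketch of those semantics using a plain dict (the real module keeps the dict inside a `ContextVar`):

```python
# Per-tool FIFO queue semantics, illustrated with a plain dict.
pending: dict[str, list[str]] = {}


def stash(tool_name: str, output: str) -> None:
    pending.setdefault(tool_name, []).append(output)


def pop(tool_name: str) -> str | None:
    queue = pending.get(tool_name)
    if not queue:
        pending.pop(tool_name, None)
        return None
    value = queue.pop(0)
    if not queue:
        del pending[tool_name]
    return value


stash("WebSearch", "first result")
stash("WebSearch", "second result")
print(pop("WebSearch"))  # -> first result
print(pop("WebSearch"))  # -> second result
print(pop("WebSearch"))  # -> None
```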
@@ -125,14 +158,63 @@ async def _execute_tool_sync(
         # Stash the full output before the SDK potentially truncates it.
         pending = _pending_tool_outputs.get(None)
         if pending is not None:
-            pending[base_tool.name] = text
+            pending.setdefault(base_tool.name, []).append(text)
+
+        content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
+
+        # If the tool result contains inline image data, add an MCP image block
+        # so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
+        image_block = _extract_image_block(text)
+        if image_block:
+            content_blocks.append(image_block)
 
         return {
-            "content": [{"type": "text", "text": text}],
+            "content": content_blocks,
             "isError": not result.success,
         }
 
 
+# MIME types that Claude can process as image content blocks.
+_SUPPORTED_IMAGE_TYPES = frozenset(
+    {"image/png", "image/jpeg", "image/gif", "image/webp"}
+)
+
+
+def _extract_image_block(text: str) -> dict[str, str] | None:
+    """Extract an MCP image content block from a tool result JSON string.
+
+    Detects workspace file responses with ``content_base64`` and an image
+    MIME type, returning an MCP-format image block that allows Claude to
+    "see" the image. Returns ``None`` if the result is not an inline image.
+    """
+    try:
+        data = json.loads(text)
+    except (json.JSONDecodeError, TypeError):
+        return None
+
+    if not isinstance(data, dict):
+        return None
+
+    mime_type = data.get("mime_type", "")
+    base64_content = data.get("content_base64", "")
+
+    # Only inline small images — large ones would exceed Claude's limits.
+    # 32 KB raw ≈ ~43 KB base64.
+    _MAX_IMAGE_BASE64_BYTES = 43_000
+    if (
+        mime_type in _SUPPORTED_IMAGE_TYPES
+        and base64_content
+        and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
+    ):
+        return {
+            "type": "image",
+            "data": base64_content,
+            "mimeType": mime_type,
+        }
+
+    return None
+
+
 def _mcp_error(message: str) -> dict[str, Any]:
     return {
         "content": [
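`_extract_image_block` only fires for small inline images: a tool-result JSON payload with a supported `mime_type` and a `content_base64` value under the ~43 KB cap. A hedged example of the kind of payload involved (field names come from the diff, the bytes are made up):

```python
import base64
import json

SUPPORTED = {"image/png", "image/jpeg", "image/gif", "image/webp"}
MAX_B64 = 43_000  # roughly 32 KB of raw image data

# A made-up read_workspace_file-style result carrying a tiny inline "PNG".
text = json.dumps(
    {
        "mime_type": "image/png",
        "content_base64": base64.b64encode(b"\x89PNG\r\n<tiny image bytes>").decode(),
    }
)

data = json.loads(text)
b64 = data.get("content_base64", "")
if data.get("mime_type") in SUPPORTED and b64 and len(b64) <= MAX_B64:
    block = {"type": "image", "data": b64, "mimeType": data["mime_type"]}
    print(block["mimeType"])  # -> image/png
```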
@@ -311,14 +393,29 @@ def create_copilot_mcp_server():
 # which provides kernel-level network isolation via unshare --net.
 # Task allows spawning sub-agents (rate-limited by security hooks).
 # WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
-_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"]
+# TodoWrite manages the task checklist shown in the UI — no security concern.
+_SDK_BUILTIN_TOOLS = [
+    "Read",
+    "Write",
+    "Edit",
+    "Glob",
+    "Grep",
+    "Task",
+    "WebSearch",
+    "TodoWrite",
+]
 
 # SDK built-in tools that must be explicitly blocked.
 # Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
 # network isolation (unshare --net) instead.
 # WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
 # Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
-SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]
+# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
+SDK_DISALLOWED_TOOLS = [
+    "Bash",
+    "WebFetch",
+    "AskUserQuestion",
+]
 
 # Tools that are blocked entirely in security hooks (defence-in-depth).
 # Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.
@@ -14,6 +14,8 @@ import json
 import logging
 import os
 import re
+import time
+from dataclasses import dataclass
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +33,16 @@ STRIPPABLE_TYPES = frozenset(
     {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
 )
 
+
+@dataclass
+class TranscriptDownload:
+    """Result of downloading a transcript with its metadata."""
+
+    content: str
+    message_count: int = 0  # session.messages length when uploaded
+    uploaded_at: float = 0.0  # epoch timestamp of upload
+
+
 # Workspace storage constants — deterministic path from session_id.
 TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
 
@@ -119,23 +131,19 @@ def read_transcript_file(transcript_path: str) -> str | None:
             content = f.read()
 
         if not content.strip():
-            logger.debug(f"[Transcript] Empty file: {transcript_path}")
             return None
 
         lines = content.strip().split("\n")
         if len(lines) < 3:
             # Raw files with ≤2 lines are metadata-only
             # (queue-operation + file-history-snapshot, no conversation).
-            logger.debug(
-                f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
-            )
             return None
 
         # Quick structural validation — parse first and last lines.
         json.loads(lines[0])
         json.loads(lines[-1])
 
-        logger.info(
+        logger.debug(
             f"[Transcript] Read {len(lines)} lines, "
             f"{len(content)} bytes from {transcript_path}"
         )
@@ -160,6 +168,41 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
 _SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
 
 
+def _encode_cwd_for_cli(cwd: str) -> str:
+    """Encode a working directory path the same way the Claude CLI does.
+
+    The CLI replaces all non-alphanumeric characters with ``-``.
+    """
+    return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
+
+
+def cleanup_cli_project_dir(sdk_cwd: str) -> None:
+    """Remove the CLI's project directory for a specific working directory.
+
+    The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
+    Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
+    safe to remove entirely after the transcript has been uploaded.
+    """
+    import shutil
+
+    cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
+    config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
+    projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
+    project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
+
+    if not project_dir.startswith(projects_base + os.sep):
+        logger.warning(
+            f"[Transcript] Cleanup path escaped projects base: {project_dir}"
+        )
+        return
+
+    if os.path.isdir(project_dir):
+        shutil.rmtree(project_dir, ignore_errors=True)
+        logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
+    else:
+        logger.debug(f"[Transcript] Project dir not found: {project_dir}")
+
+
 def write_transcript_to_tempfile(
     transcript_content: str,
     session_id: str,
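`_encode_cwd_for_cli` reproduces the CLI's project-directory naming: every non-alphanumeric character of the resolved working directory becomes `-`. A quick illustration with a made-up session path:

```python
import re

cwd = "/tmp/copilot-abc123"                  # hypothetical per-session cwd
encoded = re.sub(r"[^a-zA-Z0-9]", "-", cwd)
print(encoded)                               # -> -tmp-copilot-abc123
# The CLI's project directory would then live under
# ~/.claude/projects/-tmp-copilot-abc123/
```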
@@ -191,7 +234,7 @@ def write_transcript_to_tempfile(
         with open(jsonl_path, "w") as f:
             f.write(transcript_content)
 
-        logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
+        logger.debug(f"[Transcript] Wrote resume file: {jsonl_path}")
         return jsonl_path
 
     except OSError as e:
@@ -248,6 +291,15 @@ def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
     )
 
 
+def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
+    """Return (workspace_id, file_id, filename) for a session's transcript metadata."""
+    return (
+        TRANSCRIPT_STORAGE_PREFIX,
+        _sanitize_id(user_id),
+        f"{_sanitize_id(session_id)}.meta.json",
+    )
+
+
 def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
     """Build the full storage path string that ``retrieve()`` expects.
 
@@ -268,21 +320,30 @@ def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
     return f"local://{wid}/{fid}/{fname}"
 
 
-async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
+async def upload_transcript(
+    user_id: str,
+    session_id: str,
+    content: str,
+    message_count: int = 0,
+) -> None:
     """Strip progress entries and upload transcript to bucket storage.
 
     Safety: only overwrites when the new (stripped) transcript is larger than
     what is already stored. Since JSONL is append-only, the latest transcript
     is always the longest. This prevents a slow/stale background task from
     clobbering a newer upload from a concurrent turn.
+
+    Args:
+        message_count: ``len(session.messages)`` at upload time — used by
+            the next turn to detect staleness and compress only the gap.
     """
     from backend.util.workspace_storage import get_workspace_storage
 
     stripped = strip_progress_entries(content)
     if not validate_transcript(stripped):
         logger.warning(
-            f"[Transcript] Skipping upload — stripped content is not a valid "
-            f"transcript for session {session_id}"
+            f"[Transcript] Skipping upload — stripped content not valid "
+            f"for session {session_id}"
         )
         return
 
@@ -296,10 +357,9 @@ async def upload_transcript(user_id: str, session_id: str, content: str) -> None
     try:
         existing = await storage.retrieve(path)
         if len(existing) >= new_size:
-            logger.info(
-                f"[Transcript] Skipping upload — existing transcript "
-                f"({len(existing)}B) >= new ({new_size}B) for session "
-                f"{session_id}"
+            logger.debug(
+                f"[Transcript] Skipping upload — existing ({len(existing)}B) "
+                f">= new ({new_size}B) for session {session_id}"
             )
             return
     except (FileNotFoundError, Exception):
@@ -311,16 +371,38 @@ async def upload_transcript(user_id: str, session_id: str, content: str) -> None
         filename=fname,
         content=encoded,
     )
+
+    # Store metadata alongside the transcript so the next turn can detect
+    # staleness and only compress the gap instead of the full history.
+    # Wrapped in try/except so a metadata write failure doesn't orphan
+    # the already-uploaded transcript — the next turn will just fall back
+    # to full gap fill (msg_count=0).
+    try:
+        meta = {"message_count": message_count, "uploaded_at": time.time()}
+        mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
+        await storage.store(
+            workspace_id=mwid,
+            file_id=mfid,
+            filename=mfname,
+            content=json.dumps(meta).encode("utf-8"),
+        )
+    except Exception as e:
+        logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
 
     logger.info(
-        f"[Transcript] Uploaded {new_size} bytes "
-        f"(stripped from {len(content)}) for session {session_id}"
+        f"[Transcript] Uploaded {new_size}B "
+        f"(stripped from {len(content)}B, msg_count={message_count}) "
+        f"for session {session_id}"
     )
 
 
-async def download_transcript(user_id: str, session_id: str) -> str | None:
-    """Download transcript from bucket storage.
+async def download_transcript(
+    user_id: str, session_id: str
+) -> TranscriptDownload | None:
+    """Download transcript and metadata from bucket storage.
 
-    Returns the JSONL content string, or ``None`` if not found.
+    Returns a ``TranscriptDownload`` with the JSONL content and the
+    ``message_count`` watermark from the upload, or ``None`` if not found.
     """
     from backend.util.workspace_storage import get_workspace_storage
 
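The metadata sidecar uploaded next to the transcript is a small JSON document; the next turn reads the `message_count` watermark back from it, or falls back to 0 when it is missing. A sketch of the round-trip, with the storage calls left out:

```python
import json
import time

# What upload_transcript stores as <session>.meta.json alongside the JSONL:
meta = {"message_count": 12, "uploaded_at": time.time()}
payload = json.dumps(meta).encode("utf-8")

# What download_transcript recovers; defaults cover pre-metadata transcripts:
loaded = json.loads(payload.decode("utf-8"))
message_count = loaded.get("message_count", 0)
uploaded_at = loaded.get("uploaded_at", 0.0)
print(message_count, uploaded_at > 0)  # -> 12 True
```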
@@ -330,10 +412,6 @@ async def download_transcript(user_id: str, session_id: str) -> str | None:
     try:
         data = await storage.retrieve(path)
         content = data.decode("utf-8")
-        logger.info(
-            f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
-        )
-        return content
     except FileNotFoundError:
         logger.debug(f"[Transcript] No transcript in storage for {session_id}")
         return None
@@ -341,6 +419,36 @@ async def download_transcript(user_id: str, session_id: str) -> str | None:
         logger.warning(f"[Transcript] Failed to download transcript: {e}")
         return None
 
+    # Try to load metadata (best-effort — old transcripts won't have it)
+    message_count = 0
+    uploaded_at = 0.0
+    try:
+        from backend.util.workspace_storage import GCSWorkspaceStorage
+
+        mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
+        if isinstance(storage, GCSWorkspaceStorage):
+            blob = f"workspaces/{mwid}/{mfid}/{mfname}"
+            meta_path = f"gcs://{storage.bucket_name}/{blob}"
+        else:
+            meta_path = f"local://{mwid}/{mfid}/{mfname}"
+
+        meta_data = await storage.retrieve(meta_path)
+        meta = json.loads(meta_data.decode("utf-8"))
+        message_count = meta.get("message_count", 0)
+        uploaded_at = meta.get("uploaded_at", 0.0)
+    except (FileNotFoundError, json.JSONDecodeError, Exception):
+        pass  # No metadata — treat as unknown (msg_count=0 → always fill gap)
+
+    logger.debug(
+        f"[Transcript] Downloaded {len(content)}B "
+        f"(msg_count={message_count}) for session {session_id}"
+    )
+    return TranscriptDownload(
+        content=content,
+        message_count=message_count,
+        uploaded_at=uploaded_at,
+    )
+
 
 async def delete_transcript(user_id: str, session_id: str) -> None:
     """Delete transcript from bucket storage (e.g. after resume failure)."""
@@ -387,7 +387,7 @@ async def stream_chat_completion(
|
|||||||
if user_id:
|
if user_id:
|
||||||
log_meta["user_id"] = user_id
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
||||||
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
||||||
extra={
|
extra={
|
||||||
@@ -404,7 +404,7 @@ async def stream_chat_completion(
|
|||||||
fetch_start = time.monotonic()
|
fetch_start = time.monotonic()
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
fetch_time = (time.monotonic() - fetch_start) * 1000
|
fetch_time = (time.monotonic() - fetch_start) * 1000
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
||||||
f"n_messages={len(session.messages) if session else 0}",
|
f"n_messages={len(session.messages) if session else 0}",
|
||||||
extra={
|
extra={
|
||||||
@@ -416,7 +416,7 @@ async def stream_chat_completion(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
||||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
)
|
)
|
||||||
@@ -450,7 +450,7 @@ async def stream_chat_completion(
|
|||||||
message_length=len(message),
|
message_length=len(message),
|
||||||
)
|
)
|
||||||
posthog_time = (time.monotonic() - posthog_start) * 1000
|
posthog_time = (time.monotonic() - posthog_start) * 1000
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
||||||
)
|
)
|
||||||
@@ -458,7 +458,7 @@ async def stream_chat_completion(
|
|||||||
upsert_start = time.monotonic()
|
upsert_start = time.monotonic()
|
||||||
session = await upsert_chat_session(session)
|
session = await upsert_chat_session(session)
|
||||||
upsert_time = (time.monotonic() - upsert_start) * 1000
|
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
||||||
)
|
)
|
||||||
@@ -503,7 +503,7 @@ async def stream_chat_completion(
|
|||||||
prompt_start = time.monotonic()
|
prompt_start = time.monotonic()
|
||||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||||
prompt_time = (time.monotonic() - prompt_start) * 1000
|
prompt_time = (time.monotonic() - prompt_start) * 1000
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
||||||
)
|
)
|
||||||
@@ -537,7 +537,7 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Only yield message start for the initial call, not for continuations.
|
# Only yield message start for the initial call, not for continuations.
|
||||||
setup_time = (time.monotonic() - completion_start) * 1000
|
setup_time = (time.monotonic() - completion_start) * 1000
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
||||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
)
|
)
|
||||||
@@ -548,7 +548,7 @@ async def stream_chat_completion(
|
|||||||
yield StreamStartStep()
|
yield StreamStartStep()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.debug(
|
||||||
"[TIMING] Calling _stream_chat_chunks",
|
"[TIMING] Calling _stream_chat_chunks",
|
||||||
extra={"json_fields": log_meta},
|
extra={"json_fields": log_meta},
|
||||||
)
|
)
|
||||||
@@ -988,7 +988,7 @@ async def _stream_chat_chunks(
if session.user_id:
log_meta["user_id"] = session.user_id

-logger.info(
+logger.debug(
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
f"user={session.user_id}, n_messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},

@@ -1011,7 +1011,7 @@ async def _stream_chat_chunks(
base_url=config.base_url,
)
context_time = (time_module.perf_counter() - context_start) * 1000
-logger.info(
+logger.debug(
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
)

@@ -1053,7 +1053,7 @@ async def _stream_chat_chunks(
retry_info = (
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
)
-logger.info(
+logger.debug(
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
extra={
"json_fields": {

@@ -1093,7 +1093,7 @@ async def _stream_chat_chunks(
extra_body=extra_body,
)
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
-logger.info(
+logger.debug(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)

@@ -1142,7 +1142,7 @@ async def _stream_chat_chunks(
ttfc = (
time_module.perf_counter() - api_call_start
) * 1000
-logger.info(
+logger.debug(
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
f"(since API call), n_chunks={chunk_count}",
extra={

@@ -1210,7 +1210,7 @@ async def _stream_chat_chunks(
)
emitted_start_for_idx.add(idx)
stream_duration = time_module.perf_counter() - api_call_start
-logger.info(
+logger.debug(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",

@@ -1244,7 +1244,7 @@ async def _stream_chat_chunks(
raise

total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
-logger.info(
+logger.debug(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},

@@ -1494,8 +1494,8 @@ async def _yield_tool_call(
# Mark stream registry task as failed if it was created
try:
await stream_registry.mark_task_completed(task_id, status="failed")
-except Exception:
-pass
+except Exception as mark_err:
+logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
logger.error(
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
)

@@ -143,7 +143,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"Transcript was not uploaded to bucket after turn 1 — "
"Stop hook may not have fired or transcript was too small"
)
-logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
+logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")

# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)

@@ -117,7 +117,7 @@ async def create_task(
if user_id:
log_meta["user_id"] = user_id

-logger.info(
+logger.debug(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)

@@ -135,7 +135,7 @@ async def create_task(
redis_start = time.perf_counter()
redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
-logger.info(
+logger.debug(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)

@@ -158,7 +158,7 @@ async def create_task(
},
)
hset_time = (time.perf_counter() - hset_start) * 1000
-logger.info(
+logger.debug(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)

@@ -169,7 +169,7 @@ async def create_task(
await redis.set(op_key, task_id, ex=config.stream_ttl)

total_time = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)

@@ -230,7 +230,7 @@ async def publish_chunk(
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
-logger.info(
+logger.debug(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {

@@ -279,7 +279,7 @@ async def subscribe_to_task(
if user_id:
log_meta["user_id"] = user_id

-logger.info(
+logger.debug(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)

@@ -289,14 +289,14 @@ async def subscribe_to_task(
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
-logger.info(
+logger.debug(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)

if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {

@@ -335,7 +335,7 @@ async def subscribe_to_task(
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
-logger.info(
+logger.debug(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {

@@ -363,7 +363,7 @@ async def subscribe_to_task(
except Exception as e:
logger.warning(f"Failed to replay message: {e}")

-logger.info(
+logger.debug(
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {

@@ -376,7 +376,7 @@ async def subscribe_to_task(

# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
-logger.info(
+logger.debug(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)

@@ -387,14 +387,14 @@ async def subscribe_to_task(
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
-logger.info(
+logger.debug(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish())

total_time = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={

@@ -433,7 +433,7 @@ async def _stream_listener(
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}

-logger.info(
+logger.debug(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)

@@ -462,7 +462,7 @@ async def _stream_listener(

if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
-logger.info(
+logger.debug(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {

@@ -475,7 +475,7 @@ async def _stream_listener(
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
-logger.info(
+logger.debug(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {

@@ -526,7 +526,7 @@ async def _stream_listener(
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
-logger.info(
+logger.debug(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {

@@ -568,7 +568,7 @@ async def _stream_listener(
# Stop listening on finish
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {

@@ -587,7 +587,7 @@ async def _stream_listener(

except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {

@@ -619,7 +619,7 @@ async def _stream_listener(
finally:
# Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
-logger.info(
+logger.debug(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={

@@ -829,10 +829,13 @@ async def get_active_task_for_session(
)
await mark_task_completed(task_id, "failed")
continue
-except (ValueError, TypeError):
-pass
+except (ValueError, TypeError) as exc:
+logger.warning(
+f"[TASK_LOOKUP] Failed to parse created_at "
+f"for task {task_id[:8]}...: {exc}"
+)

-logger.info(
+logger.debug(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)

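Note on the subscribe/replay flow exercised by the stream-registry hunks above: a subscriber first replays whatever is already in the task's Redis stream, then keeps following it with a blocking `xread` until a finish marker arrives. The sketch below only illustrates that pattern; the function name, key layout, field names, and the `StreamFinish` check are assumptions rather than the project's actual helpers, and it assumes a `redis.asyncio` client created with `decode_responses=True`.

```python
from redis.asyncio import Redis


async def follow_task_stream(redis: Redis, stream_key: str, last_id: str = "0") -> None:
    # Phase 1: replay everything already in the stream so a late subscriber
    # sees the full history before any live chunks arrive.
    for _key, entries in (await redis.xread({stream_key: last_id})) or []:
        for entry_id, fields in entries:
            print("replayed", entry_id, fields)
            last_id = entry_id

    # Phase 2: follow the stream from the last replayed ID with a blocking
    # xread, until the producer appends a finish marker.
    while True:
        batch = await redis.xread({stream_key: last_id}, block=30_000, count=100)
        if not batch:
            continue  # blocking read timed out; poll again
        for _key, entries in batch:
            for entry_id, fields in entries:
                print("live", entry_id, fields)
                last_id = entry_id
                if fields.get("type") == "StreamFinish":  # assumed field name
                    return
```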
@@ -312,8 +312,18 @@ class ReadWorkspaceFileTool(BaseTool):
is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mime_type)

-# Return inline content for small text files (unless force_download_url)
-if is_small_file and is_text_file and not force_download_url:
+# Return inline content for small text/image files (unless force_download_url)
+is_image_file = file_info.mime_type in {
+"image/png",
+"image/jpeg",
+"image/gif",
+"image/webp",
+}
+if (
+is_small_file
+and (is_text_file or is_image_file)
+and not force_download_url
+):
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")

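The hunk above widens the inline-return path from small text files to small text or image files. A minimal standalone sketch of that gate follows; the size limit and the simplified text-MIME check are assumptions standing in for the tool's own `MAX_INLINE_SIZE_BYTES` and `_is_text_mime_type`.

```python
MAX_INLINE_SIZE_BYTES = 256 * 1024  # assumed limit, not the tool's real constant
INLINE_IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/gif", "image/webp"}


def should_inline(size_bytes: int, mime_type: str, force_download_url: bool) -> bool:
    # Mirror of the diff's condition: small AND (text OR image) AND not forced to URL.
    if force_download_url:
        return False
    is_small = size_bytes <= MAX_INLINE_SIZE_BYTES
    is_text = mime_type.startswith("text/")  # simplified stand-in
    is_image = mime_type in INLINE_IMAGE_MIME_TYPES
    return is_small and (is_text or is_image)
```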
@@ -599,6 +599,15 @@ def get_service_client(
if error_response and error_response.type in EXCEPTION_MAPPING:
exception_class = EXCEPTION_MAPPING[error_response.type]
args = error_response.args or [str(e)]

+# Prisma DataError subclasses expect a dict `data` arg,
+# but RPC serialization only preserves the string message
+# from exc.args. Wrap it in the expected structure so
+# the constructor doesn't crash on `.get()`.
+if issubclass(exception_class, DataError):
+msg = str(args[0]) if args else str(e)
+raise exception_class({"user_facing_error": {"message": msg}})

raise exception_class(*args)

# Otherwise categorize by HTTP status code

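Why the new `DataError` branch is needed, as a small sketch: Prisma's `DataError` constructor reads a `user_facing_error` dict from its single argument, so rebuilding the exception from the plain string that survives RPC serialization requires wrapping the message. The literal values below are illustrative; the new test in the next hunk exercises the same path.

```python
from prisma.errors import UniqueViolationError

msg = "Unique constraint failed on the fields: (`path`)"
# Passing the bare string would break inside DataError.__init__, which expects
# a dict it can call .get("user_facing_error") on; wrap the message instead.
exc = UniqueViolationError({"user_facing_error": {"message": msg}})
assert "Unique constraint" in str(exc)
assert isinstance(exc.data, dict)
```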
@@ -6,6 +6,7 @@ from unittest.mock import Mock

import httpx
import pytest
+from prisma.errors import DataError, UniqueViolationError

from backend.util.service import (
AppService,

@@ -447,6 +448,39 @@ class TestHTTPErrorRetryBehavior:

assert "Invalid parameter value" in str(exc_info.value)

+def test_prisma_data_error_reconstructed_correctly(self):
+"""Test that DataError subclasses (e.g. UniqueViolationError) are
+reconstructed without crashing.
+
+Prisma's DataError.__init__ expects a dict `data` arg with
+a 'user_facing_error' key. RPC serialization only preserves the
+string message via exc.args, so the client must wrap it in the
+expected dict structure.
+"""
+for exc_type in [DataError, UniqueViolationError]:
+mock_response = Mock()
+mock_response.status_code = 400
+mock_response.json.return_value = {
+"type": exc_type.__name__,
+"args": ["Unique constraint failed on the fields: (`path`)"],
+}
+mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+"400 Bad Request", request=Mock(), response=mock_response
+)
+
+client = get_service_client(ServiceTestClient)
+
+with pytest.raises(exc_type) as exc_info:
+client._handle_call_method_response( # type: ignore[attr-defined]
+response=mock_response, method_name="test_method"
+)
+
+# The exception should have the message preserved
+assert "Unique constraint" in str(exc_info.value)
+# And should have the expected data structure (not crash)
+assert hasattr(exc_info.value, "data")
+assert isinstance(exc_info.value.data, dict)
+
def test_client_error_status_codes_coverage(self):
"""Test that various 4xx status codes are all wrapped as HTTPClientError."""
client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]

@@ -23,6 +23,7 @@ export function CopilotPage() {
status,
error,
stop,
+isReconnecting,
createSession,
onSend,
isLoadingSession,

@@ -71,6 +72,7 @@ export function CopilotPage() {
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isCreatingSession={isCreatingSession}
+isReconnecting={isReconnecting}
onCreateSession={createSession}
onSend={onSend}
onStop={stop}

@@ -14,6 +14,8 @@ export interface ChatContainerProps {
sessionId: string | null;
isLoadingSession: boolean;
isCreatingSession: boolean;
+/** True when backend has an active stream but we haven't reconnected yet. */
+isReconnecting?: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
onStop: () => void;

@@ -26,11 +28,13 @@ export const ChatContainer = ({
sessionId,
isLoadingSession,
isCreatingSession,
+isReconnecting,
onCreateSession,
onSend,
onStop,
headerSlot,
}: ChatContainerProps) => {
+const isBusy = status === "streaming" || !!isReconnecting;
const inputLayoutId = "copilot-2-chat-input";

return (

@@ -56,8 +60,8 @@ export const ChatContainer = ({
<ChatInput
inputId="chat-input-session"
onSend={onSend}
-disabled={status === "streaming"}
-isStreaming={status === "streaming"}
+disabled={isBusy}
+isStreaming={isBusy}
onStop={onStop}
placeholder="What else can I help with?"
/>

@@ -1,63 +1,713 @@
"use client";

+import React from "react";
import { ToolUIPart } from "ai";
-import { GearIcon } from "@phosphor-icons/react";
+import {
+CheckCircleIcon,
+CircleDashedIcon,
+CircleIcon,
+FileIcon,
+FilesIcon,
+GearIcon,
+GlobeIcon,
+ListChecksIcon,
+MagnifyingGlassIcon,
+PencilSimpleIcon,
+TerminalIcon,
+TrashIcon,
+WarningDiamondIcon,
+} from "@phosphor-icons/react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
+import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
+import {
+ContentCodeBlock,
+ContentMessage,
+} from "../../components/ToolAccordion/AccordionContent";
+import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";

interface Props {
part: ToolUIPart;
}

+/* ------------------------------------------------------------------ */
+/* Tool name helpers */
+/* ------------------------------------------------------------------ */

function extractToolName(part: ToolUIPart): string {
-// ToolUIPart.type is "tool-{name}", extract the name portion.
return part.type.replace(/^tool-/, "");
}

function formatToolName(name: string): string {
-// "search_docs" → "Search docs", "Read" → "Read"
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
}

-function getAnimationText(part: ToolUIPart): string {
-const label = formatToolName(extractToolName(part));
+/* ------------------------------------------------------------------ */
+/* Tool categorization */
+/* ------------------------------------------------------------------ */

-switch (part.state) {
-case "input-streaming":
-case "input-available":
-return `Running ${label}…`;
-case "output-available":
-return `${label} completed`;
-case "output-error":
-return `${label} failed`;
+type ToolCategory =
+| "bash"
+| "web"
+| "file-read"
+| "file-write"
+| "file-delete"
+| "file-list"
+| "search"
+| "edit"
+| "todo"
+| "other";

+function getToolCategory(toolName: string): ToolCategory {
+switch (toolName) {
+case "bash_exec":
+return "bash";
+case "web_fetch":
+case "WebSearch":
+case "WebFetch":
+return "web";
+case "read_workspace_file":
+case "Read":
+return "file-read";
+case "write_workspace_file":
+case "Write":
+return "file-write";
+case "delete_workspace_file":
+return "file-delete";
+case "list_workspace_files":
+case "Glob":
+return "file-list";
+case "Grep":
+return "search";
+case "Edit":
+return "edit";
+case "TodoWrite":
+return "todo";
default:
-return `Running ${label}…`;
+return "other";
}
}

+/* ------------------------------------------------------------------ */
+/* Tool icon */
+/* ------------------------------------------------------------------ */

+function ToolIcon({
+category,
+isStreaming,
+isError,
+}: {
+category: ToolCategory;
+isStreaming: boolean;
+isError: boolean;
+}) {
+if (isError) {
+return (
+<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
+);
+}
+if (isStreaming) {
+return <OrbitLoader size={14} />;
+}

+const iconClass = "text-neutral-400";
+switch (category) {
+case "bash":
+return <TerminalIcon size={14} weight="regular" className={iconClass} />;
+case "web":
+return <GlobeIcon size={14} weight="regular" className={iconClass} />;
+case "file-read":
+return <FileIcon size={14} weight="regular" className={iconClass} />;
+case "file-write":
+return <FileIcon size={14} weight="regular" className={iconClass} />;
+case "file-delete":
+return <TrashIcon size={14} weight="regular" className={iconClass} />;
+case "file-list":
+return <FilesIcon size={14} weight="regular" className={iconClass} />;
+case "search":
+return (
+<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
+);
+case "edit":
+return (
+<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
+);
+case "todo":
+return (
+<ListChecksIcon size={14} weight="regular" className={iconClass} />
+);
+default:
+return <GearIcon size={14} weight="regular" className={iconClass} />;
+}
+}

+/* ------------------------------------------------------------------ */
+/* Accordion icon (larger, for the accordion header) */
+/* ------------------------------------------------------------------ */

+function AccordionIcon({ category }: { category: ToolCategory }) {
+switch (category) {
+case "bash":
+return <TerminalIcon size={32} weight="light" />;
+case "web":
+return <GlobeIcon size={32} weight="light" />;
+case "file-read":
+case "file-write":
+return <FileIcon size={32} weight="light" />;
+case "file-delete":
+return <TrashIcon size={32} weight="light" />;
+case "file-list":
+return <FilesIcon size={32} weight="light" />;
+case "search":
+return <MagnifyingGlassIcon size={32} weight="light" />;
+case "edit":
+return <PencilSimpleIcon size={32} weight="light" />;
+case "todo":
+return <ListChecksIcon size={32} weight="light" />;
+default:
+return <GearIcon size={32} weight="light" />;
+}
+}

+/* ------------------------------------------------------------------ */
+/* Input extraction */
+/* ------------------------------------------------------------------ */

+function getInputSummary(toolName: string, input: unknown): string | null {
+if (!input || typeof input !== "object") return null;
+const inp = input as Record<string, unknown>;

+switch (toolName) {
+case "bash_exec":
+return typeof inp.command === "string" ? inp.command : null;
+case "web_fetch":
+case "WebFetch":
+return typeof inp.url === "string" ? inp.url : null;
+case "WebSearch":
+return typeof inp.query === "string" ? inp.query : null;
+case "read_workspace_file":
+case "Read":
+return (
+(typeof inp.file_path === "string" ? inp.file_path : null) ??
+(typeof inp.path === "string" ? inp.path : null)
+);
+case "write_workspace_file":
+case "Write":
+return (
+(typeof inp.file_path === "string" ? inp.file_path : null) ??
+(typeof inp.path === "string" ? inp.path : null)
+);
+case "delete_workspace_file":
+return typeof inp.file_path === "string" ? inp.file_path : null;
+case "Glob":
+return typeof inp.pattern === "string" ? inp.pattern : null;
+case "Grep":
+return typeof inp.pattern === "string" ? inp.pattern : null;
+case "Edit":
+return typeof inp.file_path === "string" ? inp.file_path : null;
+case "TodoWrite": {
+// Extract the in-progress task name for the status line
+const todos = Array.isArray(inp.todos) ? inp.todos : [];
+const active = todos.find(
+(t: Record<string, unknown>) => t.status === "in_progress",
+);
+if (active && typeof active.activeForm === "string")
+return active.activeForm;
+if (active && typeof active.content === "string") return active.content;
+return null;
+}
+default:
+return null;
+}
+}

+function truncate(text: string, maxLen: number): string {
+if (text.length <= maxLen) return text;
+return text.slice(0, maxLen).trimEnd() + "…";
+}

+/* ------------------------------------------------------------------ */
+/* Animation text */
+/* ------------------------------------------------------------------ */

+function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
+const toolName = extractToolName(part);
+const summary = getInputSummary(toolName, part.input);
+const shortSummary = summary ? truncate(summary, 60) : null;

+switch (part.state) {
+case "input-streaming":
+case "input-available": {
+switch (category) {
+case "bash":
+return shortSummary ? `Running: ${shortSummary}` : "Running command…";
+case "web":
+if (toolName === "WebSearch") {
+return shortSummary
+? `Searching "${shortSummary}"`
+: "Searching the web…";
+}
+return shortSummary
+? `Fetching ${shortSummary}`
+: "Fetching web content…";
+case "file-read":
+return shortSummary ? `Reading ${shortSummary}` : "Reading file…";
+case "file-write":
+return shortSummary ? `Writing ${shortSummary}` : "Writing file…";
+case "file-delete":
+return shortSummary ? `Deleting ${shortSummary}` : "Deleting file…";
+case "file-list":
+return shortSummary ? `Listing ${shortSummary}` : "Listing files…";
+case "search":
+return shortSummary
+? `Searching for "${shortSummary}"`
+: "Searching…";
+case "edit":
+return shortSummary ? `Editing ${shortSummary}` : "Editing file…";
+case "todo":
+return shortSummary ? `${shortSummary}` : "Updating task list…";
+default:
+return `Running ${formatToolName(toolName)}…`;
+}
+}
+case "output-available": {
+switch (category) {
+case "bash": {
+const exitCode = getExitCode(part.output);
+if (exitCode !== null && exitCode !== 0) {
+return `Command exited with code ${exitCode}`;
+}
+return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
+}
+case "web":
+if (toolName === "WebSearch") {
+return shortSummary
+? `Searched "${shortSummary}"`
+: "Web search completed";
+}
+return shortSummary
+? `Fetched ${shortSummary}`
+: "Fetched web content";
+case "file-read":
+return shortSummary ? `Read ${shortSummary}` : "File read completed";
+case "file-write":
+return shortSummary ? `Wrote ${shortSummary}` : "File written";
+case "file-delete":
+return shortSummary ? `Deleted ${shortSummary}` : "File deleted";
+case "file-list":
+return "Listed files";
+case "search":
+return shortSummary
+? `Searched for "${shortSummary}"`
+: "Search completed";
+case "edit":
+return shortSummary ? `Edited ${shortSummary}` : "Edit completed";
+case "todo":
+return "Updated task list";
+default:
+return `${formatToolName(toolName)} completed`;
+}
+}
+case "output-error": {
+switch (category) {
+case "bash":
+return "Command failed";
+case "web":
+return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
+default:
+return `${formatToolName(toolName)} failed`;
+}
+}
+default:
+return `Running ${formatToolName(toolName)}…`;
+}
+}

+/* ------------------------------------------------------------------ */
+/* Output parsing helpers */
+/* ------------------------------------------------------------------ */

+function parseOutput(output: unknown): Record<string, unknown> | null {
+if (!output) return null;
+if (typeof output === "object") return output as Record<string, unknown>;
+if (typeof output === "string") {
+const trimmed = output.trim();
+if (!trimmed) return null;
+try {
+const parsed = JSON.parse(trimmed);
+if (
+typeof parsed === "object" &&
+parsed !== null &&
+!Array.isArray(parsed)
+)
+return parsed;
+} catch {
+// Return as a message wrapper for plain text output
+return { _raw: trimmed };
+}
+}
+return null;
+}

+/**
+* Extract text from MCP-style content blocks.
+* SDK built-in tools (WebSearch, etc.) may return `{content: [{type:"text", text:"..."}]}`.
+*/
+function extractMcpText(output: Record<string, unknown>): string | null {
+if (Array.isArray(output.content)) {
+const texts = (output.content as Array<Record<string, unknown>>)
+.filter((b) => b.type === "text" && typeof b.text === "string")
+.map((b) => b.text as string);
+if (texts.length > 0) return texts.join("\n");
+}
+return null;
+}

+function getExitCode(output: unknown): number | null {
+const parsed = parseOutput(output);
+if (!parsed) return null;
+if (typeof parsed.exit_code === "number") return parsed.exit_code;
+return null;
+}

+function getStringField(
+obj: Record<string, unknown>,
+...keys: string[]
+): string | null {
+for (const key of keys) {
+if (typeof obj[key] === "string" && obj[key].length > 0)
+return obj[key] as string;
+}
+return null;
+}

+/* ------------------------------------------------------------------ */
+/* Accordion content per tool category */
+/* ------------------------------------------------------------------ */

+interface AccordionData {
+title: string;
+description?: string;
+content: React.ReactNode;
+}

+function getBashAccordionData(
+input: unknown,
+output: Record<string, unknown>,
+): AccordionData {
+const inp = (input && typeof input === "object" ? input : {}) as Record<
+string,
+unknown
+>;
+const command = typeof inp.command === "string" ? inp.command : "Command";

+const stdout = getStringField(output, "stdout");
+const stderr = getStringField(output, "stderr");
+const exitCode =
+typeof output.exit_code === "number" ? output.exit_code : null;
+const timedOut = output.timed_out === true;
+const message = getStringField(output, "message");

+const title = timedOut
+? "Command timed out"
+: exitCode !== null && exitCode !== 0
+? `Command failed (exit ${exitCode})`
+: "Command output";

+return {
+title,
+description: truncate(command, 80),
+content: (
+<div className="space-y-2">
+{stdout && (
+<div>
+<p className="mb-1 text-xs font-medium text-slate-500">stdout</p>
+<ContentCodeBlock>{truncate(stdout, 2000)}</ContentCodeBlock>
+</div>
+)}
+{stderr && (
+<div>
+<p className="mb-1 text-xs font-medium text-slate-500">stderr</p>
+<ContentCodeBlock>{truncate(stderr, 1000)}</ContentCodeBlock>
+</div>
+)}
+{!stdout && !stderr && message && (
+<ContentMessage>{message}</ContentMessage>
+)}
+</div>
+),
+};
+}

+function getWebAccordionData(
+input: unknown,
+output: Record<string, unknown>,
+): AccordionData {
+const inp = (input && typeof input === "object" ? input : {}) as Record<
+string,
+unknown
+>;
+const url =
+getStringField(inp as Record<string, unknown>, "url", "query") ??
+"Web content";

+// Try direct string fields first, then MCP content blocks, then raw JSON
+let content = getStringField(output, "content", "text", "_raw");
+if (!content) content = extractMcpText(output);
+if (!content) {
+// Fallback: render the raw JSON so the accordion isn't empty
+try {
+const raw = JSON.stringify(output, null, 2);
+if (raw !== "{}") content = raw;
+} catch {
+/* empty */
+}
+}

+const statusCode =
+typeof output.status_code === "number" ? output.status_code : null;
+const message = getStringField(output, "message");

+return {
+title: statusCode
+? `Response (${statusCode})`
+: url
+? "Web fetch"
+: "Search results",
+description: truncate(url, 80),
+content: content ? (
+<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
+) : message ? (
+<ContentMessage>{message}</ContentMessage>
+) : Object.keys(output).length > 0 ? (
+<ContentCodeBlock>
+{truncate(JSON.stringify(output, null, 2), 2000)}
+</ContentCodeBlock>
+) : null,
+};
+}

+function getFileAccordionData(
+input: unknown,
+output: Record<string, unknown>,
+): AccordionData {
+const inp = (input && typeof input === "object" ? input : {}) as Record<
+string,
+unknown
+>;
+const filePath =
+getStringField(
+inp as Record<string, unknown>,
+"file_path",
+"path",
+"pattern",
+) ?? "File";
+const content = getStringField(output, "content", "text", "_raw");
+const message = getStringField(output, "message");
+// For Glob/list results, try to show file list
+const files = Array.isArray(output.files)
+? output.files.filter((f: unknown): f is string => typeof f === "string")
+: null;

+return {
+title: message ?? "File output",
+description: truncate(filePath, 80),
+content: (
+<div className="space-y-2">
+{content && (
+<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
+)}
+{files && files.length > 0 && (
+<ContentCodeBlock>
+{truncate(files.join("\n"), 2000)}
+</ContentCodeBlock>
+)}
+{!content && !files && message && (
+<ContentMessage>{message}</ContentMessage>
+)}
+</div>
+),
+};
+}

+interface TodoItem {
+content: string;
+status: "pending" | "in_progress" | "completed";
+activeForm?: string;
+}

+function getTodoAccordionData(input: unknown): AccordionData {
+const inp = (input && typeof input === "object" ? input : {}) as Record<
+string,
+unknown
+>;
+const todos: TodoItem[] = Array.isArray(inp.todos)
+? inp.todos.filter(
+(t: unknown): t is TodoItem =>
+typeof t === "object" &&
+t !== null &&
+typeof (t as TodoItem).content === "string",
+)
+: [];

+const completed = todos.filter((t) => t.status === "completed").length;
+const total = todos.length;

+return {
+title: "Task list",
+description: `${completed}/${total} completed`,
+content: (
+<div className="space-y-1 py-1">
+{todos.map((todo, i) => (
+<div key={i} className="flex items-start gap-2 text-xs">
+<span className="mt-0.5 flex-shrink-0">
+{todo.status === "completed" ? (
+<CheckCircleIcon
+size={14}
+weight="fill"
+className="text-green-500"
+/>
+) : todo.status === "in_progress" ? (
+<CircleDashedIcon
+size={14}
+weight="bold"
+className="text-blue-500"
+/>
+) : (
+<CircleIcon
+size={14}
+weight="regular"
+className="text-neutral-400"
+/>
+)}
+</span>
+<span
+className={
+todo.status === "completed"
+? "text-muted-foreground line-through"
+: todo.status === "in_progress"
+? "font-medium text-foreground"
+: "text-muted-foreground"
+}
+>
+{todo.content}
+</span>
+</div>
+))}
+</div>
+),
+};
+}

+function getDefaultAccordionData(
+output: Record<string, unknown>,
+): AccordionData {
+const message = getStringField(output, "message");
+const raw = output._raw;
+const mcpText = extractMcpText(output);

+let displayContent: string;
+if (typeof raw === "string") {
+displayContent = raw;
+} else if (mcpText) {
+displayContent = mcpText;
+} else if (message) {
+displayContent = message;
+} else {
+try {
+displayContent = JSON.stringify(output, null, 2);
+} catch {
+displayContent = String(output);
+}
+}

+return {
+title: "Output",
+description: message ?? undefined,
+content: (
+<ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
+),
+};
+}

+function getAccordionData(
+category: ToolCategory,
+input: unknown,
+output: Record<string, unknown>,
+): AccordionData {
+switch (category) {
+case "bash":
+return getBashAccordionData(input, output);
+case "web":
+return getWebAccordionData(input, output);
+case "file-read":
+case "file-write":
+case "file-delete":
+case "file-list":
+case "search":
+case "edit":
+return getFileAccordionData(input, output);
+case "todo":
+return getTodoAccordionData(input);
+default:
+return getDefaultAccordionData(output);
+}
+}

+/* ------------------------------------------------------------------ */
+/* Component */
+/* ------------------------------------------------------------------ */

export function GenericTool({ part }: Props) {
+const toolName = extractToolName(part);
+const category = getToolCategory(toolName);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
+const text = getAnimationText(part, category);

+const output = parseOutput(part.output);
+const hasOutput =
+part.state === "output-available" &&
+!!output &&
+Object.keys(output).length > 0;
+const hasError = isError && !!output;

+// TodoWrite: always show accordion from input (the todo list lives in input)
+const hasTodoInput =
+category === "todo" &&
+part.input &&
+typeof part.input === "object" &&
+Array.isArray((part.input as Record<string, unknown>).todos);

+const showAccordion = hasOutput || hasError || hasTodoInput;
+const accordionData = showAccordion
+? getAccordionData(category, part.input, output ?? {})
+: null;

return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
-<GearIcon
-size={14}
-weight="regular"
-className={
-isError
-? "text-red-500"
-: isStreaming
-? "animate-spin text-neutral-500"
-: "text-neutral-400"
-}
+<ToolIcon
+category={category}
+isStreaming={isStreaming}
+isError={isError}
/>
<MorphingTextAnimation
-text={getAnimationText(part)}
+text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>

+{showAccordion && accordionData ? (
+<ToolAccordion
+icon={<AccordionIcon category={category} />}
+title={accordionData.title}
+description={accordionData.description}
+titleClassName={isError ? "text-red-500" : undefined}
+>
+{accordionData.content}
+</ToolAccordion>
+) : null}
</div>
);
}

@@ -50,6 +50,14 @@ export function useChatSession() {
);
}, [sessionQuery.data, sessionId]);

+// Expose active_stream info so the caller can trigger manual resume
+// after hydration completes (rather than relying on AI SDK's built-in
+// resume which fires before hydration).
+const hasActiveStream = useMemo(() => {
+if (sessionQuery.data?.status !== 200) return false;
+return !!sessionQuery.data.data.active_stream;
+}, [sessionQuery.data]);

const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
usePostV2CreateSession({
mutation: {

@@ -102,6 +110,7 @@ export function useChatSession() {
sessionId,
setSessionId,
hydratedMessages,
+hasActiveStream,
isLoadingSession: sessionQuery.isLoading,
createSession,
isCreatingSession,

@@ -29,6 +29,7 @@ export function useCopilotPage() {
sessionId,
setSessionId,
hydratedMessages,
+hasActiveStream,
isLoadingSession,
createSession,
isCreatingSession,

@@ -80,14 +81,31 @@ export function useCopilotPage() {
},
};
},
+// Resume: GET goes to the same URL as POST (backend uses
+// method to distinguish). Override the default formula which
+// would append /{chatId}/stream to the existing path.
+prepareReconnectToStreamRequest: () => ({
+api: `/api/chat/sessions/${sessionId}/stream`,
+}),
})
: null,
[sessionId],
);

-const { messages, sendMessage, stop, status, error, setMessages } = useChat({
+const {
+messages,
+sendMessage,
+stop,
+status,
+error,
+setMessages,
+resumeStream,
+} = useChat({
id: sessionId ?? undefined,
transport: transport ?? undefined,
+// Don't use resume: true — it fires before hydration completes, causing
+// the hydrated messages to overwrite the resumed stream. Instead we
+// call resumeStream() manually after hydration + active_stream detection.
});

// Abort the stream if the backend doesn't start sending data within 12s.

@@ -108,13 +126,31 @@ export function useCopilotPage() {
return () => clearTimeout(timer);
}, [status]);

+// Hydrate messages from the REST session endpoint.
+// Skip hydration while streaming to avoid overwriting the live stream.
useEffect(() => {
if (!hydratedMessages || hydratedMessages.length === 0) return;
+if (status === "streaming" || status === "submitted") return;
setMessages((prev) => {
if (prev.length >= hydratedMessages.length) return prev;
return hydratedMessages;
});
-}, [hydratedMessages, setMessages]);
+}, [hydratedMessages, setMessages, status]);

+// Resume an active stream AFTER hydration completes.
+// The backend returns active_stream info when a task is still running.
+// We wait for hydration so the AI SDK has the conversation history
+// before the resumed stream appends the in-progress assistant message.
+const hasResumedRef = useRef<string | null>(null);
+useEffect(() => {
+if (!hasActiveStream || !sessionId) return;
+if (!hydratedMessages || hydratedMessages.length === 0) return;
+if (status === "streaming" || status === "submitted") return;
+// Only resume once per session to avoid re-triggering after stream ends
+if (hasResumedRef.current === sessionId) return;
+hasResumedRef.current = sessionId;
+resumeStream();
+}, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);

// Poll session endpoint when a long-running tool (create_agent, edit_agent)
// is in progress. When the backend completes, the session data will contain

@@ -197,12 +233,18 @@ export function useCopilotPage() {
}
}, [isDeleting]);

+// True while we know the backend has an active stream but haven't
+// reconnected yet. Used to disable the send button and show stop UI.
+const isReconnecting =
+hasActiveStream && status !== "streaming" && status !== "submitted";

return {
sessionId,
messages,
status,
error,
stop,
+isReconnecting,
isLoadingSession,
isCreatingSession,
isUserLoading,