mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-18 18:44:42 -05:00
Compare commits
7 Commits
otto/secrt
...
copilot/sd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6128dd75f | ||
|
|
c4f5f7c8b8 | ||
|
|
8af4e0bf7d | ||
|
|
dc77e7b6e6 | ||
|
|
ba75cc28b5 | ||
|
|
15bcdae4e8 | ||
|
|
e9ba7e51db |
@@ -4,7 +4,6 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
|
||||
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@@ -164,21 +163,23 @@ class CoPilotExecutor(AppProcess):
|
||||
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
|
||||
)
|
||||
|
||||
# Shutdown executor
|
||||
# Clean up worker threads (closes per-loop workspace storage sessions)
|
||||
if self._executor:
|
||||
from .processor import cleanup_worker
|
||||
|
||||
logger.info(f"[cleanup {pid}] Cleaning up workers...")
|
||||
futures = []
|
||||
for _ in range(self._executor._max_workers):
|
||||
futures.append(self._executor.submit(cleanup_worker))
|
||||
for f in futures:
|
||||
try:
|
||||
f.result(timeout=10)
|
||||
except Exception as e:
|
||||
logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
|
||||
|
||||
logger.info(f"[cleanup {pid}] Shutting down executor...")
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Close async resources (workspace storage aiohttp session, etc.)
|
||||
try:
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(shutdown_workspace_storage())
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[cleanup {pid}] Error closing workspace storage: {e}")
|
||||
|
||||
# Release any remaining locks
|
||||
for task_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
|
||||
@@ -60,6 +60,18 @@ def init_worker():
|
||||
_tls.processor.on_executor_start()
|
||||
|
||||
|
||||
def cleanup_worker():
|
||||
"""Clean up the processor for the current worker thread.
|
||||
|
||||
Should be called before the worker thread's event loop is destroyed so
|
||||
that event-loop-bound resources (e.g. ``aiohttp.ClientSession``) are
|
||||
closed on the correct loop.
|
||||
"""
|
||||
processor: CoPilotProcessor | None = getattr(_tls, "processor", None)
|
||||
if processor is not None:
|
||||
processor.cleanup()
|
||||
|
||||
|
||||
# ============ Processor Class ============ #
|
||||
|
||||
|
||||
@@ -98,6 +110,28 @@ class CoPilotProcessor:
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
||||
|
||||
Shuts down the workspace storage instance that belongs to this
|
||||
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
|
||||
runs on the same loop that created the session.
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
shutdown_workspace_storage(), self.execution_loop
|
||||
)
|
||||
future.result(timeout=5)
|
||||
except Exception as e:
|
||||
logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")
|
||||
|
||||
# Stop the event loop
|
||||
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
|
||||
self.execution_thread.join(timeout=5)
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} cleaned up")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
def execute(
|
||||
self,
|
||||
|
||||
@@ -53,6 +53,7 @@ class SDKResponseAdapter:
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.resolved_tool_calls: set[str] = set()
|
||||
self.task_id: str | None = None
|
||||
self.step_open = False
|
||||
|
||||
@@ -74,6 +75,10 @@ class SDKResponseAdapter:
|
||||
self.step_open = True
|
||||
|
||||
elif isinstance(sdk_message, AssistantMessage):
|
||||
# Flush any SDK built-in tool calls that didn't get a UserMessage
|
||||
# result (e.g. WebSearch, Read handled internally by the CLI).
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
if not self.step_open:
|
||||
@@ -111,6 +116,8 @@ class SDKResponseAdapter:
|
||||
# UserMessage carries tool results back from tool execution.
|
||||
content = sdk_message.content
|
||||
blocks = content if isinstance(content, list) else []
|
||||
resolved_in_blocks: set[str] = set()
|
||||
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
@@ -132,6 +139,37 @@ class SDKResponseAdapter:
|
||||
success=not (block.is_error or False),
|
||||
)
|
||||
)
|
||||
resolved_in_blocks.add(block.tool_use_id)
|
||||
|
||||
# Handle SDK built-in tool results carried via parent_tool_use_id
|
||||
# instead of (or in addition to) ToolResultBlock content.
|
||||
parent_id = sdk_message.parent_tool_use_id
|
||||
if parent_id and parent_id not in resolved_in_blocks:
|
||||
tool_info = self.current_tool_calls.get(parent_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Try stashed output first (from PostToolUse hook),
|
||||
# then tool_use_result dict, then string content.
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if not output:
|
||||
tur = sdk_message.tool_use_result
|
||||
if tur is not None:
|
||||
output = _extract_tool_use_result(tur)
|
||||
if not output and isinstance(content, str) and content.strip():
|
||||
output = content.strip()
|
||||
|
||||
if output:
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=parent_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
resolved_in_blocks.add(parent_id)
|
||||
|
||||
self.resolved_tool_calls.update(resolved_in_blocks)
|
||||
|
||||
# Close the current step after tool results — the next
|
||||
# AssistantMessage will open a new step for the continuation.
|
||||
@@ -140,6 +178,7 @@ class SDKResponseAdapter:
|
||||
self.step_open = False
|
||||
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
self._end_text_if_open(responses)
|
||||
# Close the step before finishing.
|
||||
if self.step_open:
|
||||
@@ -149,7 +188,7 @@ class SDKResponseAdapter:
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
@@ -180,6 +219,59 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Emit outputs for tool calls that didn't receive a UserMessage result.
|
||||
|
||||
SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
|
||||
internally without surfacing a separate ``UserMessage`` with
|
||||
``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
|
||||
output, which we pop and emit here before the next ``AssistantMessage``
|
||||
starts.
|
||||
"""
|
||||
flushed = False
|
||||
for tool_id, tool_info in self.current_tool_calls.items():
|
||||
if tool_id in self.resolved_tool_calls:
|
||||
continue
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if output is not None:
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.debug(
|
||||
f"Flushed pending output for built-in tool {tool_name} "
|
||||
f"(call {tool_id})"
|
||||
)
|
||||
else:
|
||||
# No output available — emit an empty output so the frontend
|
||||
# transitions the tool from input-available to output-available
|
||||
# (stops the spinner).
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
output="",
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.debug(
|
||||
f"Flushed empty output for unresolved tool {tool_name} "
|
||||
f"(call {tool_id})"
|
||||
)
|
||||
|
||||
if flushed and self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
|
||||
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
||||
"""Extract a string output from a ToolResultBlock's content field."""
|
||||
@@ -199,3 +291,30 @@ def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
||||
return json.dumps(content)
|
||||
except (TypeError, ValueError):
|
||||
return str(content)
|
||||
|
||||
|
||||
def _extract_tool_use_result(result: object) -> str:
|
||||
"""Extract a string from a UserMessage's ``tool_use_result`` dict.
|
||||
|
||||
SDK built-in tools may store their result in ``tool_use_result``
|
||||
instead of (or in addition to) ``ToolResultBlock`` content blocks.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if isinstance(result, dict):
|
||||
# Try common result keys
|
||||
for key in ("content", "text", "output", "stdout", "result"):
|
||||
val = result.get(key)
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
# Fall back to JSON serialization of the whole dict
|
||||
try:
|
||||
return json.dumps(result)
|
||||
except (TypeError, ValueError):
|
||||
return str(result)
|
||||
if result is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(result)
|
||||
except (TypeError, ValueError):
|
||||
return str(result)
|
||||
|
||||
@@ -16,6 +16,7 @@ from .tool_adapter import (
|
||||
DANGEROUS_PATTERNS,
|
||||
MCP_TOOL_PREFIX,
|
||||
WORKSPACE_SCOPED_TOOLS,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -224,10 +225,25 @@ def create_security_hooks(
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log successful tool executions for observability."""
|
||||
"""Log successful tool executions and stash SDK built-in tool outputs.
|
||||
|
||||
MCP tools stash their output in ``_execute_tool_sync`` before the
|
||||
SDK can truncate it. SDK built-in tools (WebSearch, Read, etc.)
|
||||
are executed by the CLI internally — this hook captures their
|
||||
output so the response adapter can forward it to the frontend.
|
||||
"""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||
|
||||
# Stash output for SDK built-in tools so the response adapter can
|
||||
# emit StreamToolOutputAvailable even when the CLI doesn't surface
|
||||
# a separate UserMessage with ToolResultBlock content.
|
||||
if not tool_name.startswith(MCP_TOOL_PREFIX):
|
||||
tool_response = input_data.get("tool_response")
|
||||
if tool_response is not None:
|
||||
stash_pending_tool_output(tool_name, tool_response)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_failure_hook(
|
||||
|
||||
@@ -47,6 +47,7 @@ from .tool_adapter import (
|
||||
set_execution_context,
|
||||
)
|
||||
from .transcript import (
|
||||
cleanup_cli_project_dir,
|
||||
download_transcript,
|
||||
read_transcript_file,
|
||||
upload_transcript,
|
||||
@@ -86,9 +87,12 @@ _SDK_TOOL_SUPPLEMENT = """
|
||||
for shell commands — it runs in a network-isolated sandbox.
|
||||
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
|
||||
same working directory. Files created by one are readable by the other.
|
||||
These files are **ephemeral** — they exist only for the current session.
|
||||
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
|
||||
for files that should persist across sessions (stored in cloud storage).
|
||||
- **IMPORTANT — File persistence**: Your working directory is **ephemeral** —
|
||||
files are lost between turns. When you create or modify important files
|
||||
(code, configs, outputs), you MUST save them using `write_workspace_file`
|
||||
so they persist. Use `read_workspace_file` and `list_workspace_files` to
|
||||
access files saved in previous turns. If a "Files from previous turns"
|
||||
section is present above, those files are available via `read_workspace_file`.
|
||||
- Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
@@ -268,48 +272,28 @@ def _make_sdk_cwd(session_id: str) -> str:
|
||||
|
||||
|
||||
def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""Remove SDK tool-result files for a specific session working directory.
|
||||
"""Remove SDK session artifacts for a specific working directory.
|
||||
|
||||
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
|
||||
We clean only the specific cwd's results to avoid race conditions between
|
||||
concurrent sessions.
|
||||
Cleans up:
|
||||
- ``~/.claude/projects/<encoded-cwd>/`` — CLI session transcripts and
|
||||
tool-result files. Each SDK turn uses a unique cwd, so this directory
|
||||
is safe to remove entirely.
|
||||
- ``/tmp/copilot-<session>/`` — the ephemeral working directory.
|
||||
|
||||
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
|
||||
Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
|
||||
the session_id.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Validate cwd is under the expected prefix
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||
return
|
||||
|
||||
# SDK encodes the cwd path by replacing '/' with '-'
|
||||
encoded_cwd = normalized.replace("/", "-")
|
||||
# Clean the CLI's project directory (transcripts + tool-results).
|
||||
cleanup_cli_project_dir(cwd)
|
||||
|
||||
# Construct the project directory path (known-safe home expansion)
|
||||
claude_projects = os.path.expanduser("~/.claude/projects")
|
||||
project_dir = os.path.join(claude_projects, encoded_cwd)
|
||||
|
||||
# Security check 3: Validate project_dir is under ~/.claude/projects
|
||||
project_dir = os.path.normpath(project_dir)
|
||||
if not project_dir.startswith(claude_projects):
|
||||
logger.warning(
|
||||
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
|
||||
)
|
||||
return
|
||||
|
||||
results_dir = os.path.join(project_dir, "tool-results")
|
||||
if os.path.isdir(results_dir):
|
||||
for filename in os.listdir(results_dir):
|
||||
file_path = os.path.join(results_dir, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Also clean up the temp cwd directory itself
|
||||
# Clean up the temp cwd directory itself.
|
||||
try:
|
||||
shutil.rmtree(normalized, ignore_errors=True)
|
||||
except OSError:
|
||||
@@ -519,6 +503,7 @@ async def stream_chat_completion_sdk(
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
logger.debug(f"[SDK] Stop hook: path={transcript_path!r}")
|
||||
|
||||
security_hooks = create_security_hooks(
|
||||
user_id,
|
||||
@@ -530,18 +515,20 @@ async def stream_chat_completion_sdk(
|
||||
# --- Resume strategy: download transcript from bucket ---
|
||||
resume_file: str | None = None
|
||||
use_resume = False
|
||||
transcript_msg_count = 0 # watermark: session.messages length at upload
|
||||
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
transcript_content = await download_transcript(user_id, session_id)
|
||||
if transcript_content and validate_transcript(transcript_content):
|
||||
dl = await download_transcript(user_id, session_id)
|
||||
if dl and validate_transcript(dl.content):
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
transcript_content, session_id, sdk_cwd
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
logger.info(
|
||||
f"[SDK] Using --resume with transcript "
|
||||
f"({len(transcript_content)} bytes)"
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
@@ -582,11 +569,35 @@ async def stream_chat_completion_sdk(
|
||||
# Build query: with --resume the CLI already has full
|
||||
# context, so we only send the new message. Without
|
||||
# resume, compress history into a context prefix.
|
||||
#
|
||||
# Hybrid mode: if the transcript is stale (upload missed
|
||||
# some turns), compress only the gap and prepend it so
|
||||
# the agent has transcript context + missed turns.
|
||||
query_message = current_message
|
||||
if not use_resume and len(session.messages) > 1:
|
||||
current_msg_count = len(session.messages)
|
||||
|
||||
if use_resume and transcript_msg_count > 0:
|
||||
# Transcript covers messages[0..M-1]. Current session
|
||||
# has N messages (last one is the new user msg).
|
||||
# Gap = messages[M .. N-2] (everything between upload
|
||||
# and the current turn).
|
||||
if transcript_msg_count < current_msg_count - 1:
|
||||
gap = session.messages[transcript_msg_count:-1]
|
||||
gap_context = _format_conversation_context(gap)
|
||||
if gap_context:
|
||||
logger.info(
|
||||
f"[SDK] Transcript stale: covers {transcript_msg_count} "
|
||||
f"of {current_msg_count} messages, compressing "
|
||||
f"{len(gap)} missed messages"
|
||||
)
|
||||
query_message = (
|
||||
f"{gap_context}\n\n"
|
||||
f"Now, the user says:\n{current_message}"
|
||||
)
|
||||
elif not use_resume and current_msg_count > 1:
|
||||
logger.warning(
|
||||
f"[SDK] Using compression fallback for session "
|
||||
f"{session_id} ({len(session.messages)} messages) — "
|
||||
f"{session_id} ({current_msg_count} messages) — "
|
||||
f"no transcript available for --resume"
|
||||
)
|
||||
compressed = await _compress_conversation_history(session)
|
||||
@@ -597,10 +608,10 @@ async def stream_chat_completion_sdk(
|
||||
f"Now, the user says:\n{current_message}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
|
||||
logger.debug(
|
||||
f"[SDK] Sending query ({len(session.messages)} msgs, "
|
||||
f"resume={use_resume})"
|
||||
)
|
||||
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
|
||||
await client.query(query_message, session_id=session_id)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
@@ -681,25 +692,27 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# --- Capture transcript while CLI is still alive ---
|
||||
# Must happen INSIDE async with: close() sends SIGTERM
|
||||
# which kills the CLI before it can flush the JSONL.
|
||||
if (
|
||||
config.claude_agent_use_resume
|
||||
and user_id
|
||||
and captured_transcript.available
|
||||
):
|
||||
# Give CLI time to flush JSONL writes before we read
|
||||
await asyncio.sleep(0.5)
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# After async with the SDK task group has exited, so the Stop
|
||||
# hook has already fired and the CLI has been SIGTERMed. The
|
||||
# CLI uses appendFileSync, so all writes are safely on disk.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
# With --resume the CLI appends to the resume file (most
|
||||
# complete). Otherwise use the Stop hook path.
|
||||
if use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
elif captured_transcript.path:
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
if raw_transcript:
|
||||
task = asyncio.create_task(
|
||||
_upload_transcript_bg(user_id, session_id, raw_transcript)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
else:
|
||||
logger.debug("[SDK] Stop hook fired but transcript not usable")
|
||||
else:
|
||||
raw_transcript = None
|
||||
|
||||
if raw_transcript:
|
||||
await _try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
@@ -731,14 +744,31 @@ async def stream_chat_completion_sdk(
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
|
||||
async def _upload_transcript_bg(
|
||||
user_id: str, session_id: str, raw_content: str
|
||||
) -> None:
|
||||
"""Background task to strip progress entries and upload transcript."""
|
||||
async def _try_upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
raw_content: str,
|
||||
message_count: int = 0,
|
||||
) -> bool:
|
||||
"""Strip progress entries and upload transcript (with timeout).
|
||||
|
||||
Returns True if the upload completed without error.
|
||||
"""
|
||||
try:
|
||||
await upload_transcript(user_id, session_id, raw_content)
|
||||
async with asyncio.timeout(30):
|
||||
await upload_transcript(
|
||||
user_id, session_id, raw_content, message_count=message_count
|
||||
)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
|
||||
logger.error(
|
||||
f"[SDK] Failed to upload transcript for {session_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
|
||||
@@ -103,6 +103,31 @@ def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
return pending.pop(tool_name, None)
|
||||
|
||||
|
||||
def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
"""Stash tool output for later retrieval by the response adapter.
|
||||
|
||||
Used by the PostToolUse hook to capture SDK built-in tool outputs
|
||||
(WebSearch, Read, etc.) that aren't available through the MCP stash
|
||||
mechanism in ``_execute_tool_sync``.
|
||||
|
||||
Does NOT overwrite an existing stash entry — MCP tools stash in
|
||||
``_execute_tool_sync`` first with the guaranteed-full output.
|
||||
"""
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is None:
|
||||
return
|
||||
# Don't overwrite MCP tool stash (which has the guaranteed-full output)
|
||||
if tool_name in pending:
|
||||
return
|
||||
if isinstance(output, str):
|
||||
pending[tool_name] = output
|
||||
else:
|
||||
try:
|
||||
pending[tool_name] = json.dumps(output)
|
||||
except (TypeError, ValueError):
|
||||
pending[tool_name] = str(output)
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
base_tool: BaseTool,
|
||||
user_id: str | None,
|
||||
@@ -127,12 +152,54 @@ async def _execute_tool_sync(
|
||||
if pending is not None:
|
||||
pending[base_tool.name] = text
|
||||
|
||||
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
|
||||
|
||||
# If the tool result contains inline image data, add an MCP image block
|
||||
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
|
||||
image_block = _extract_image_block(text)
|
||||
if image_block:
|
||||
content_blocks.append(image_block)
|
||||
|
||||
return {
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"content": content_blocks,
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
|
||||
# MIME types that Claude can process as image content blocks.
|
||||
_SUPPORTED_IMAGE_TYPES = frozenset(
|
||||
{"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
)
|
||||
|
||||
|
||||
def _extract_image_block(text: str) -> dict[str, str] | None:
|
||||
"""Extract an MCP image content block from a tool result JSON string.
|
||||
|
||||
Detects workspace file responses with ``content_base64`` and an image
|
||||
MIME type, returning an MCP-format image block that allows Claude to
|
||||
"see" the image. Returns ``None`` if the result is not an inline image.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
mime_type = data.get("mime_type", "")
|
||||
base64_content = data.get("content_base64", "")
|
||||
|
||||
if mime_type in _SUPPORTED_IMAGE_TYPES and base64_content:
|
||||
return {
|
||||
"type": "image",
|
||||
"data": base64_content,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _mcp_error(message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"content": [
|
||||
@@ -311,14 +378,29 @@ def create_copilot_mcp_server():
|
||||
# which provides kernel-level network isolation via unshare --net.
|
||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
|
||||
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"]
|
||||
# TodoWrite manages the task checklist shown in the UI — no security concern.
|
||||
_SDK_BUILTIN_TOOLS = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
]
|
||||
|
||||
# SDK built-in tools that must be explicitly blocked.
|
||||
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
|
||||
# network isolation (unshare --net) instead.
|
||||
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
|
||||
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
|
||||
SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]
|
||||
# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
|
||||
SDK_DISALLOWED_TOOLS = [
|
||||
"Bash",
|
||||
"WebFetch",
|
||||
"AskUserQuestion",
|
||||
]
|
||||
|
||||
# Tools that are blocked entirely in security hooks (defence-in-depth).
|
||||
# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.
|
||||
|
||||
@@ -14,6 +14,8 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,6 +33,16 @@ STRIPPABLE_TYPES = frozenset(
|
||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
"""Result of downloading a transcript with its metadata."""
|
||||
|
||||
content: str
|
||||
message_count: int = 0 # session.messages length when uploaded
|
||||
uploaded_at: float = 0.0 # epoch timestamp of upload
|
||||
|
||||
|
||||
# Workspace storage constants — deterministic path from session_id.
|
||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
||||
|
||||
@@ -119,23 +131,19 @@ def read_transcript_file(transcript_path: str) -> str | None:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug(f"[Transcript] Empty file: {transcript_path}")
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 3:
|
||||
# Raw files with ≤2 lines are metadata-only
|
||||
# (queue-operation + file-history-snapshot, no conversation).
|
||||
logger.debug(
|
||||
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Quick structural validation — parse first and last lines.
|
||||
json.loads(lines[0])
|
||||
json.loads(lines[-1])
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
f"{len(content)} bytes from {transcript_path}"
|
||||
)
|
||||
@@ -160,6 +168,39 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
||||
)
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
transcript_content: str,
|
||||
session_id: str,
|
||||
@@ -191,7 +232,7 @@ def write_transcript_to_tempfile(
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
logger.debug(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
@@ -248,6 +289,15 @@ def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
)
|
||||
|
||||
|
||||
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.meta.json",
|
||||
)
|
||||
|
||||
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path string that ``retrieve()`` expects.
|
||||
|
||||
@@ -268,21 +318,30 @@ def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
|
||||
async def upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
content: str,
|
||||
message_count: int = 0,
|
||||
) -> None:
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
|
||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
||||
what is already stored. Since JSONL is append-only, the latest transcript
|
||||
is always the longest. This prevents a slow/stale background task from
|
||||
clobbering a newer upload from a concurrent turn.
|
||||
|
||||
Args:
|
||||
message_count: ``len(session.messages)`` at upload time — used by
|
||||
the next turn to detect staleness and compress only the gap.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
logger.warning(
|
||||
f"[Transcript] Skipping upload — stripped content is not a valid "
|
||||
f"transcript for session {session_id}"
|
||||
f"[Transcript] Skipping upload — stripped content not valid "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -296,10 +355,9 @@ async def upload_transcript(user_id: str, session_id: str, content: str) -> None
|
||||
try:
|
||||
existing = await storage.retrieve(path)
|
||||
if len(existing) >= new_size:
|
||||
logger.info(
|
||||
f"[Transcript] Skipping upload — existing transcript "
|
||||
f"({len(existing)}B) >= new ({new_size}B) for session "
|
||||
f"{session_id}"
|
||||
logger.debug(
|
||||
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
|
||||
f">= new ({new_size}B) for session {session_id}"
|
||||
)
|
||||
return
|
||||
except (FileNotFoundError, Exception):
|
||||
@@ -311,16 +369,32 @@ async def upload_transcript(user_id: str, session_id: str, content: str) -> None
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
)
|
||||
|
||||
# Store metadata alongside the transcript so the next turn can detect
|
||||
# staleness and only compress the gap instead of the full history.
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
await storage.store(
|
||||
workspace_id=mwid,
|
||||
file_id=mfid,
|
||||
filename=mfname,
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Uploaded {new_size} bytes "
|
||||
f"(stripped from {len(content)}) for session {session_id}"
|
||||
f"[Transcript] Uploaded {new_size}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}) "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
|
||||
|
||||
async def download_transcript(user_id: str, session_id: str) -> str | None:
|
||||
"""Download transcript from bucket storage.
|
||||
async def download_transcript(
|
||||
user_id: str, session_id: str
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
|
||||
Returns the JSONL content string, or ``None`` if not found.
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
@@ -330,10 +404,6 @@ async def download_transcript(user_id: str, session_id: str) -> str | None:
|
||||
try:
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
logger.info(
|
||||
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
|
||||
)
|
||||
return content
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||
return None
|
||||
@@ -341,6 +411,36 @@ async def download_transcript(user_id: str, session_id: str) -> str | None:
|
||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
message_count = 0
|
||||
uploaded_at = 0.0
|
||||
try:
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
if isinstance(storage, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{mwid}/{mfid}/{mfname}"
|
||||
meta_path = f"gcs://{storage.bucket_name}/{blob}"
|
||||
else:
|
||||
meta_path = f"local://{mwid}/{mfid}/{mfname}"
|
||||
|
||||
meta_data = await storage.retrieve(meta_path)
|
||||
meta = json.loads(meta_data.decode("utf-8"))
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, json.JSONDecodeError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
|
||||
logger.debug(
|
||||
f"[Transcript] Downloaded {len(content)}B "
|
||||
f"(msg_count={message_count}) for session {session_id}"
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
uploaded_at=uploaded_at,
|
||||
)
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||
|
||||
@@ -387,7 +387,7 @@ async def stream_chat_completion(
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
||||
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
||||
extra={
|
||||
@@ -404,7 +404,7 @@ async def stream_chat_completion(
|
||||
fetch_start = time.monotonic()
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
fetch_time = (time.monotonic() - fetch_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
||||
f"n_messages={len(session.messages) if session else 0}",
|
||||
extra={
|
||||
@@ -416,7 +416,7 @@ async def stream_chat_completion(
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||
)
|
||||
@@ -450,7 +450,7 @@ async def stream_chat_completion(
|
||||
message_length=len(message),
|
||||
)
|
||||
posthog_time = (time.monotonic() - posthog_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
||||
)
|
||||
@@ -458,7 +458,7 @@ async def stream_chat_completion(
|
||||
upsert_start = time.monotonic()
|
||||
session = await upsert_chat_session(session)
|
||||
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
||||
)
|
||||
@@ -503,7 +503,7 @@ async def stream_chat_completion(
|
||||
prompt_start = time.monotonic()
|
||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||
prompt_time = (time.monotonic() - prompt_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
||||
)
|
||||
@@ -537,7 +537,7 @@ async def stream_chat_completion(
|
||||
|
||||
# Only yield message start for the initial call, not for continuations.
|
||||
setup_time = (time.monotonic() - completion_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
@@ -548,7 +548,7 @@ async def stream_chat_completion(
|
||||
yield StreamStartStep()
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"[TIMING] Calling _stream_chat_chunks",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
@@ -988,7 +988,7 @@ async def _stream_chat_chunks(
|
||||
if session.user_id:
|
||||
log_meta["user_id"] = session.user_id
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
||||
f"user={session.user_id}, n_messages={len(session.messages)}",
|
||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||
@@ -1011,7 +1011,7 @@ async def _stream_chat_chunks(
|
||||
base_url=config.base_url,
|
||||
)
|
||||
context_time = (time_module.perf_counter() - context_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
||||
)
|
||||
@@ -1053,7 +1053,7 @@ async def _stream_chat_chunks(
|
||||
retry_info = (
|
||||
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
||||
)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -1093,7 +1093,7 @@ async def _stream_chat_chunks(
|
||||
extra_body=extra_body,
|
||||
)
|
||||
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
||||
)
|
||||
@@ -1142,7 +1142,7 @@ async def _stream_chat_chunks(
|
||||
ttfc = (
|
||||
time_module.perf_counter() - api_call_start
|
||||
) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
||||
f"(since API call), n_chunks={chunk_count}",
|
||||
extra={
|
||||
@@ -1210,7 +1210,7 @@ async def _stream_chat_chunks(
|
||||
)
|
||||
emitted_start_for_idx.add(idx)
|
||||
stream_duration = time_module.perf_counter() - api_call_start
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
||||
f"duration={stream_duration:.2f}s, "
|
||||
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
||||
@@ -1244,7 +1244,7 @@ async def _stream_chat_chunks(
|
||||
raise
|
||||
|
||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
|
||||
f"session={session.session_id}, user={session.user_id}",
|
||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||
@@ -1494,8 +1494,8 @@ async def _yield_tool_call(
|
||||
# Mark stream registry task as failed if it was created
|
||||
try:
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as mark_err:
|
||||
logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
|
||||
logger.error(
|
||||
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||
)
|
||||
|
||||
@@ -143,7 +143,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
"Transcript was not uploaded to bucket after turn 1 — "
|
||||
"Stop hook may not have fired or transcript was too small"
|
||||
)
|
||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
|
||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
|
||||
@@ -117,7 +117,7 @@ async def create_task(
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
@@ -135,7 +135,7 @@ async def create_task(
|
||||
redis_start = time.perf_counter()
|
||||
redis = await get_redis_async()
|
||||
redis_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
||||
)
|
||||
@@ -158,7 +158,7 @@ async def create_task(
|
||||
},
|
||||
)
|
||||
hset_time = (time.perf_counter() - hset_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
||||
)
|
||||
@@ -169,7 +169,7 @@ async def create_task(
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||
)
|
||||
@@ -230,7 +230,7 @@ async def publish_chunk(
|
||||
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
||||
or total_time > 50
|
||||
):
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -279,7 +279,7 @@ async def subscribe_to_task(
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
||||
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
||||
)
|
||||
@@ -289,14 +289,14 @@ async def subscribe_to_task(
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||
)
|
||||
|
||||
if not meta:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -335,7 +335,7 @@ async def subscribe_to_task(
|
||||
xread_start = time.perf_counter()
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -363,7 +363,7 @@ async def subscribe_to_task(
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -376,7 +376,7 @@ async def subscribe_to_task(
|
||||
|
||||
# Step 2: If task is still running, start stream listener for live updates
|
||||
if task_status == "running":
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"[TIMING] Task still running, starting _stream_listener",
|
||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||
)
|
||||
@@ -387,14 +387,14 @@ async def subscribe_to_task(
|
||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||
else:
|
||||
# Task is completed/failed - add finish marker
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||
)
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
||||
f"n_messages_replayed={replayed_count}",
|
||||
extra={
|
||||
@@ -433,7 +433,7 @@ async def _stream_listener(
|
||||
if log_meta is None:
|
||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
||||
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
||||
)
|
||||
@@ -462,7 +462,7 @@ async def _stream_listener(
|
||||
|
||||
if messages:
|
||||
msg_count = sum(len(msgs) for _, msgs in messages)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -475,7 +475,7 @@ async def _stream_listener(
|
||||
)
|
||||
elif xread_time > 1000:
|
||||
# Only log timeouts (30s blocking)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -526,7 +526,7 @@ async def _stream_listener(
|
||||
if first_message_time is None:
|
||||
first_message_time = time.perf_counter()
|
||||
elapsed = (first_message_time - start_time) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -568,7 +568,7 @@ async def _stream_listener(
|
||||
# Stop listening on finish
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -587,7 +587,7 @@ async def _stream_listener(
|
||||
|
||||
except asyncio.CancelledError:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -619,7 +619,7 @@ async def _stream_listener(
|
||||
finally:
|
||||
# Clean up listener task mapping on exit
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
||||
f"delivered={messages_delivered}, xread_count={xread_count}",
|
||||
extra={
|
||||
@@ -829,10 +829,13 @@ async def get_active_task_for_session(
|
||||
)
|
||||
await mark_task_completed(task_id, "failed")
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning(
|
||||
f"[TASK_LOOKUP] Failed to parse created_at "
|
||||
f"for task {task_id[:8]}...: {exc}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||
)
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) {
|
||||
id
|
||||
identifier
|
||||
title
|
||||
description
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -205,7 +204,6 @@ class SearchFeatureRequestsTool(BaseTool):
|
||||
id=node["id"],
|
||||
identifier=node["identifier"],
|
||||
title=node["title"],
|
||||
description=node.get("description"),
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
@@ -239,7 +237,11 @@ class CreateFeatureRequestTool(BaseTool):
|
||||
"Create a new feature request or add a customer need to an existing one. "
|
||||
"Always search first with search_feature_requests to avoid duplicates. "
|
||||
"If a matching request exists, pass its ID as existing_issue_id to add "
|
||||
"the user's need to it instead of creating a duplicate."
|
||||
"the user's need to it instead of creating a duplicate. "
|
||||
"IMPORTANT: Never include personally identifiable information (PII) in "
|
||||
"the title or description — no names, emails, phone numbers, company "
|
||||
"names, or other identifying details. Write titles and descriptions in "
|
||||
"generic, feature-focused language."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -249,11 +251,20 @@ class CreateFeatureRequestTool(BaseTool):
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title for the feature request.",
|
||||
"description": (
|
||||
"Title for the feature request. Must be generic and "
|
||||
"feature-focused — do not include any user names, emails, "
|
||||
"company names, or other PII."
|
||||
),
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Detailed description of what the user wants and why.",
|
||||
"description": (
|
||||
"Detailed description of what the user wants and why. "
|
||||
"Must not contain any personally identifiable information "
|
||||
"(PII) — describe the feature need generically without "
|
||||
"referencing specific users, companies, or contact details."
|
||||
),
|
||||
},
|
||||
"existing_issue_id": {
|
||||
"type": "string",
|
||||
|
||||
@@ -117,13 +117,11 @@ class TestSearchFeatureRequestsTool:
|
||||
"id": "id-1",
|
||||
"identifier": "FR-1",
|
||||
"title": "Dark mode",
|
||||
"description": "Add dark mode support",
|
||||
},
|
||||
{
|
||||
"id": "id-2",
|
||||
"identifier": "FR-2",
|
||||
"title": "Dark theme",
|
||||
"description": None,
|
||||
},
|
||||
]
|
||||
patcher, _ = _mock_linear_config(query_return=_search_response(nodes))
|
||||
|
||||
@@ -486,7 +486,6 @@ class FeatureRequestInfo(BaseModel):
|
||||
id: str
|
||||
identifier: str
|
||||
title: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class FeatureRequestSearchResponse(ToolResponseBase):
|
||||
|
||||
@@ -93,7 +93,15 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.data.workspace import (
|
||||
count_workspace_files,
|
||||
create_workspace_file,
|
||||
get_or_create_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_file_by_path,
|
||||
list_workspace_files,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -274,7 +282,13 @@ class DatabaseManager(AppService):
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = _(count_workspace_files)
|
||||
create_workspace_file = _(create_workspace_file)
|
||||
get_or_create_workspace = _(get_or_create_workspace)
|
||||
get_workspace_file = _(get_workspace_file)
|
||||
get_workspace_file_by_path = _(get_workspace_file_by_path)
|
||||
list_workspace_files = _(list_workspace_files)
|
||||
soft_delete_workspace_file = _(soft_delete_workspace_file)
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
@@ -438,7 +452,13 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = d.count_workspace_files
|
||||
create_workspace_file = d.create_workspace_file
|
||||
get_or_create_workspace = d.get_or_create_workspace
|
||||
get_workspace_file = d.get_workspace_file
|
||||
get_workspace_file_by_path = d.get_workspace_file_by_path
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
|
||||
@@ -164,21 +164,23 @@ async def create_workspace_file(
|
||||
|
||||
async def get_workspace_file(
|
||||
file_id: str,
|
||||
workspace_id: Optional[str] = None,
|
||||
workspace_id: str,
|
||||
) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Get a workspace file by ID.
|
||||
|
||||
Args:
|
||||
file_id: The file ID
|
||||
workspace_id: Optional workspace ID for validation
|
||||
workspace_id: Workspace ID for scoping (required)
|
||||
|
||||
Returns:
|
||||
WorkspaceFile instance or None
|
||||
"""
|
||||
where_clause: dict = {"id": file_id, "isDeleted": False}
|
||||
if workspace_id:
|
||||
where_clause["workspaceId"] = workspace_id
|
||||
where_clause: UserWorkspaceFileWhereInput = {
|
||||
"id": file_id,
|
||||
"isDeleted": False,
|
||||
"workspaceId": workspace_id,
|
||||
}
|
||||
|
||||
file = await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||
return WorkspaceFile.from_db(file) if file else None
|
||||
@@ -268,7 +270,7 @@ async def count_workspace_files(
|
||||
Returns:
|
||||
Number of files
|
||||
"""
|
||||
where_clause: dict = {"workspaceId": workspace_id}
|
||||
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
||||
if not include_deleted:
|
||||
where_clause["isDeleted"] = False
|
||||
|
||||
@@ -283,7 +285,7 @@ async def count_workspace_files(
|
||||
|
||||
async def soft_delete_workspace_file(
|
||||
file_id: str,
|
||||
workspace_id: Optional[str] = None,
|
||||
workspace_id: str,
|
||||
) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Soft-delete a workspace file.
|
||||
@@ -293,7 +295,7 @@ async def soft_delete_workspace_file(
|
||||
|
||||
Args:
|
||||
file_id: The file ID
|
||||
workspace_id: Optional workspace ID for validation
|
||||
workspace_id: Workspace ID for scoping (required)
|
||||
|
||||
Returns:
|
||||
Updated WorkspaceFile instance or None if not found
|
||||
|
||||
@@ -28,7 +28,7 @@ from typing import (
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, responses
|
||||
from prisma.errors import DataError
|
||||
from prisma.errors import DataError, UniqueViolationError
|
||||
from pydantic import BaseModel, TypeAdapter, create_model
|
||||
|
||||
import backend.util.exceptions as exceptions
|
||||
@@ -201,6 +201,7 @@ EXCEPTION_MAPPING = {
|
||||
UnhealthyServiceError,
|
||||
HTTPClientError,
|
||||
HTTPServerError,
|
||||
UniqueViolationError,
|
||||
*[
|
||||
ErrorType
|
||||
for _, ErrorType in inspect.getmembers(exceptions)
|
||||
@@ -416,6 +417,9 @@ class AppService(BaseAppService, ABC):
|
||||
self.fastapi_app.add_exception_handler(
|
||||
DataError, self._handle_internal_http_error(400)
|
||||
)
|
||||
self.fastapi_app.add_exception_handler(
|
||||
UniqueViolationError, self._handle_internal_http_error(400)
|
||||
)
|
||||
self.fastapi_app.add_exception_handler(
|
||||
Exception, self._handle_internal_http_error(500)
|
||||
)
|
||||
@@ -478,6 +482,7 @@ def get_service_client(
|
||||
# Don't retry these specific exceptions that won't be fixed by retrying
|
||||
ValueError, # Invalid input/parameters
|
||||
DataError, # Prisma data integrity errors (foreign key, unique constraints)
|
||||
UniqueViolationError, # Unique constraint violations
|
||||
KeyError, # Missing required data
|
||||
TypeError, # Wrong data types
|
||||
AttributeError, # Missing attributes
|
||||
|
||||
@@ -12,15 +12,8 @@ from typing import Optional
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
create_workspace_file,
|
||||
get_workspace_file,
|
||||
get_workspace_file_by_path,
|
||||
list_workspace_files,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||
@@ -125,8 +118,9 @@ class WorkspaceManager:
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
db = workspace_db()
|
||||
resolved_path = self._resolve_path(path)
|
||||
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||
file = await db.get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||
if file is None:
|
||||
raise FileNotFoundError(f"File not found at path: {resolved_path}")
|
||||
|
||||
@@ -146,7 +140,8 @@ class WorkspaceManager:
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
file = await get_workspace_file(file_id, self.workspace_id)
|
||||
db = workspace_db()
|
||||
file = await db.get_workspace_file(file_id, self.workspace_id)
|
||||
if file is None:
|
||||
raise FileNotFoundError(f"File not found: {file_id}")
|
||||
|
||||
@@ -204,8 +199,10 @@ class WorkspaceManager:
|
||||
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
|
||||
# This ensures the new file is written to storage BEFORE the old one is deleted,
|
||||
# preventing data loss if the new write fails
|
||||
db = workspace_db()
|
||||
|
||||
if not overwrite:
|
||||
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||
existing = await db.get_workspace_file_by_path(self.workspace_id, path)
|
||||
if existing is not None:
|
||||
raise ValueError(f"File already exists at path: {path}")
|
||||
|
||||
@@ -232,7 +229,7 @@ class WorkspaceManager:
|
||||
# Create database record - handle race condition where another request
|
||||
# created a file at the same path between our check and create
|
||||
try:
|
||||
file = await create_workspace_file(
|
||||
file = await db.create_workspace_file(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
name=filename,
|
||||
@@ -246,12 +243,12 @@ class WorkspaceManager:
|
||||
# Race condition: another request created a file at this path
|
||||
if overwrite:
|
||||
# Re-fetch and delete the conflicting file, then retry
|
||||
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||
existing = await db.get_workspace_file_by_path(self.workspace_id, path)
|
||||
if existing:
|
||||
await self.delete_file(existing.id)
|
||||
# Retry the create - if this also fails, clean up storage file
|
||||
try:
|
||||
file = await create_workspace_file(
|
||||
file = await db.create_workspace_file(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
name=filename,
|
||||
@@ -314,8 +311,9 @@ class WorkspaceManager:
|
||||
List of WorkspaceFile instances
|
||||
"""
|
||||
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||
db = workspace_db()
|
||||
|
||||
return await list_workspace_files(
|
||||
return await db.list_workspace_files(
|
||||
workspace_id=self.workspace_id,
|
||||
path_prefix=effective_path,
|
||||
limit=limit,
|
||||
@@ -332,7 +330,8 @@ class WorkspaceManager:
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
file = await get_workspace_file(file_id, self.workspace_id)
|
||||
db = workspace_db()
|
||||
file = await db.get_workspace_file(file_id, self.workspace_id)
|
||||
if file is None:
|
||||
return False
|
||||
|
||||
@@ -345,7 +344,7 @@ class WorkspaceManager:
|
||||
# Continue with database soft-delete even if storage delete fails
|
||||
|
||||
# Soft-delete database record
|
||||
result = await soft_delete_workspace_file(file_id, self.workspace_id)
|
||||
result = await db.soft_delete_workspace_file(file_id, self.workspace_id)
|
||||
return result is not None
|
||||
|
||||
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
|
||||
@@ -362,7 +361,8 @@ class WorkspaceManager:
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
file = await get_workspace_file(file_id, self.workspace_id)
|
||||
db = workspace_db()
|
||||
file = await db.get_workspace_file(file_id, self.workspace_id)
|
||||
if file is None:
|
||||
raise FileNotFoundError(f"File not found: {file_id}")
|
||||
|
||||
@@ -379,7 +379,8 @@ class WorkspaceManager:
|
||||
Returns:
|
||||
WorkspaceFile instance or None
|
||||
"""
|
||||
return await get_workspace_file(file_id, self.workspace_id)
|
||||
db = workspace_db()
|
||||
return await db.get_workspace_file(file_id, self.workspace_id)
|
||||
|
||||
async def get_file_info_by_path(self, path: str) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
@@ -394,8 +395,9 @@ class WorkspaceManager:
|
||||
Returns:
|
||||
WorkspaceFile instance or None
|
||||
"""
|
||||
db = workspace_db()
|
||||
resolved_path = self._resolve_path(path)
|
||||
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||
return await db.get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||
|
||||
async def get_file_count(
|
||||
self,
|
||||
@@ -417,7 +419,8 @@ class WorkspaceManager:
|
||||
Number of files
|
||||
"""
|
||||
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||
db = workspace_db()
|
||||
|
||||
return await count_workspace_files(
|
||||
return await db.count_workspace_files(
|
||||
self.workspace_id, path_prefix=effective_path
|
||||
)
|
||||
|
||||
@@ -93,7 +93,14 @@ class WorkspaceStorageBackend(ABC):
|
||||
|
||||
|
||||
class GCSWorkspaceStorage(WorkspaceStorageBackend):
|
||||
"""Google Cloud Storage implementation for workspace storage."""
|
||||
"""Google Cloud Storage implementation for workspace storage.
|
||||
|
||||
Each instance owns a single ``aiohttp.ClientSession`` and GCS async
|
||||
client. Because ``ClientSession`` is bound to the event loop on which it
|
||||
was created, callers that run on separate loops (e.g. copilot executor
|
||||
worker threads) **must** obtain their own ``GCSWorkspaceStorage`` instance
|
||||
via :func:`get_workspace_storage` which is event-loop-aware.
|
||||
"""
|
||||
|
||||
def __init__(self, bucket_name: str):
|
||||
self.bucket_name = bucket_name
|
||||
@@ -337,60 +344,73 @@ class LocalWorkspaceStorage(WorkspaceStorageBackend):
|
||||
raise ValueError(f"Invalid storage path format: {storage_path}")
|
||||
|
||||
|
||||
# Global storage backend instance
|
||||
_workspace_storage: Optional[WorkspaceStorageBackend] = None
|
||||
# ---------------------------------------------------------------------------
|
||||
# Storage instance management
|
||||
# ---------------------------------------------------------------------------
|
||||
# ``aiohttp.ClientSession`` is bound to the event loop where it is created.
|
||||
# The copilot executor runs each worker in its own thread with a dedicated
|
||||
# event loop, so a single global ``GCSWorkspaceStorage`` instance would break.
|
||||
#
|
||||
# For **local storage** a single shared instance is fine (no async I/O).
|
||||
# For **GCS storage** we keep one instance *per event loop* so every loop
|
||||
# gets its own ``ClientSession``.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_local_storage: Optional[LocalWorkspaceStorage] = None
|
||||
_gcs_storages: dict[int, GCSWorkspaceStorage] = {}
|
||||
_storage_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_workspace_storage() -> WorkspaceStorageBackend:
|
||||
"""Return a workspace storage backend for the **current** event loop.
|
||||
|
||||
* Local storage → single shared instance (no event-loop affinity).
|
||||
* GCS storage → one instance per event loop to avoid cross-loop
|
||||
``aiohttp`` errors.
|
||||
"""
|
||||
Get the workspace storage backend instance.
|
||||
global _local_storage
|
||||
|
||||
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
|
||||
"""
|
||||
global _workspace_storage
|
||||
config = Config()
|
||||
|
||||
if _workspace_storage is None:
|
||||
async with _storage_lock:
|
||||
if _workspace_storage is None:
|
||||
config = Config()
|
||||
# --- Local storage (shared) ---
|
||||
if not config.media_gcs_bucket_name:
|
||||
if _local_storage is None:
|
||||
storage_dir = (
|
||||
config.workspace_storage_dir if config.workspace_storage_dir else None
|
||||
)
|
||||
logger.info(f"Using local workspace storage: {storage_dir or 'default'}")
|
||||
_local_storage = LocalWorkspaceStorage(storage_dir)
|
||||
return _local_storage
|
||||
|
||||
if config.media_gcs_bucket_name:
|
||||
logger.info(
|
||||
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
|
||||
)
|
||||
_workspace_storage = GCSWorkspaceStorage(
|
||||
config.media_gcs_bucket_name
|
||||
)
|
||||
else:
|
||||
storage_dir = (
|
||||
config.workspace_storage_dir
|
||||
if config.workspace_storage_dir
|
||||
else None
|
||||
)
|
||||
logger.info(
|
||||
f"Using local workspace storage: {storage_dir or 'default'}"
|
||||
)
|
||||
_workspace_storage = LocalWorkspaceStorage(storage_dir)
|
||||
|
||||
return _workspace_storage
|
||||
# --- GCS storage (per event loop) ---
|
||||
loop_id = id(asyncio.get_running_loop())
|
||||
if loop_id not in _gcs_storages:
|
||||
logger.info(
|
||||
f"Creating GCS workspace storage for loop {loop_id}: "
|
||||
f"{config.media_gcs_bucket_name}"
|
||||
)
|
||||
_gcs_storages[loop_id] = GCSWorkspaceStorage(config.media_gcs_bucket_name)
|
||||
return _gcs_storages[loop_id]
|
||||
|
||||
|
||||
async def shutdown_workspace_storage() -> None:
|
||||
"""
|
||||
Properly shutdown the global workspace storage backend.
|
||||
"""Shut down workspace storage for the **current** event loop.
|
||||
|
||||
Closes aiohttp sessions and other resources for GCS backend.
|
||||
Should be called during application shutdown.
|
||||
Closes the ``aiohttp`` session owned by the current loop's GCS instance.
|
||||
Each worker thread should call this on its own loop before the loop is
|
||||
destroyed. The REST API lifespan hook calls it for the main server loop.
|
||||
"""
|
||||
global _workspace_storage
|
||||
global _local_storage
|
||||
|
||||
if _workspace_storage is not None:
|
||||
async with _storage_lock:
|
||||
if _workspace_storage is not None:
|
||||
if isinstance(_workspace_storage, GCSWorkspaceStorage):
|
||||
await _workspace_storage.close()
|
||||
_workspace_storage = None
|
||||
loop_id = id(asyncio.get_running_loop())
|
||||
storage = _gcs_storages.pop(loop_id, None)
|
||||
if storage is not None:
|
||||
await storage.close()
|
||||
|
||||
# Clear local storage only when the last GCS instance is gone
|
||||
# (i.e. full shutdown, not just a single worker stopping).
|
||||
if not _gcs_storages:
|
||||
_local_storage = None
|
||||
|
||||
|
||||
def compute_file_checksum(content: bytes) -> str:
|
||||
|
||||
@@ -1,21 +1,39 @@
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CaretDownIcon } from "@phosphor-icons/react";
|
||||
|
||||
export const NodeAdvancedToggle = ({ nodeId }: { nodeId: string }) => {
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
};
|
||||
|
||||
export function NodeAdvancedToggle({ nodeId }: Props) {
|
||||
const showAdvanced = useNodeStore(
|
||||
(state) => state.nodeAdvancedStates[nodeId] || false,
|
||||
);
|
||||
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-b-xlarge border-t border-zinc-200 bg-white px-5 py-3.5">
|
||||
<Text variant="body" className="font-medium text-slate-700">
|
||||
Advanced
|
||||
</Text>
|
||||
<Switch
|
||||
onCheckedChange={(checked) => setShowAdvanced(nodeId, checked)}
|
||||
checked={showAdvanced}
|
||||
/>
|
||||
<div className="flex items-center justify-start gap-2 bg-white px-5 pb-3.5">
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="h-fit min-w-0 p-0 hover:border-transparent hover:bg-transparent"
|
||||
onClick={() => setShowAdvanced(nodeId, !showAdvanced)}
|
||||
aria-expanded={showAdvanced}
|
||||
>
|
||||
<Text
|
||||
variant="body"
|
||||
as="span"
|
||||
className="flex items-center gap-2 !font-semibold text-slate-700"
|
||||
>
|
||||
Advanced{" "}
|
||||
<CaretDownIcon
|
||||
size={16}
|
||||
weight="bold"
|
||||
className={`transition-transform ${showAdvanced ? "rotate-180" : ""}`}
|
||||
aria-hidden
|
||||
/>
|
||||
</Text>
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,63 +1,695 @@
|
||||
"use client";
|
||||
|
||||
import { ToolUIPart } from "ai";
|
||||
import { GearIcon } from "@phosphor-icons/react";
|
||||
import {
|
||||
CheckCircleIcon,
|
||||
CircleDashedIcon,
|
||||
CircleIcon,
|
||||
FileIcon,
|
||||
FilesIcon,
|
||||
GearIcon,
|
||||
GlobeIcon,
|
||||
ListChecksIcon,
|
||||
MagnifyingGlassIcon,
|
||||
PencilSimpleIcon,
|
||||
TerminalIcon,
|
||||
TrashIcon,
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
import {
|
||||
ContentCodeBlock,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool name helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function extractToolName(part: ToolUIPart): string {
|
||||
// ToolUIPart.type is "tool-{name}", extract the name portion.
|
||||
return part.type.replace(/^tool-/, "");
|
||||
}
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
// "search_docs" → "Search docs", "Read" → "Read"
|
||||
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
function getAnimationText(part: ToolUIPart): string {
|
||||
const label = formatToolName(extractToolName(part));
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool categorization */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available":
|
||||
return `Running ${label}…`;
|
||||
case "output-available":
|
||||
return `${label} completed`;
|
||||
case "output-error":
|
||||
return `${label} failed`;
|
||||
type ToolCategory =
|
||||
| "bash"
|
||||
| "web"
|
||||
| "file-read"
|
||||
| "file-write"
|
||||
| "file-delete"
|
||||
| "file-list"
|
||||
| "search"
|
||||
| "edit"
|
||||
| "todo"
|
||||
| "other";
|
||||
|
||||
function getToolCategory(toolName: string): ToolCategory {
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return "bash";
|
||||
case "web_fetch":
|
||||
case "WebSearch":
|
||||
case "WebFetch":
|
||||
return "web";
|
||||
case "read_workspace_file":
|
||||
case "Read":
|
||||
return "file-read";
|
||||
case "write_workspace_file":
|
||||
case "Write":
|
||||
return "file-write";
|
||||
case "delete_workspace_file":
|
||||
return "file-delete";
|
||||
case "list_workspace_files":
|
||||
case "Glob":
|
||||
return "file-list";
|
||||
case "Grep":
|
||||
return "search";
|
||||
case "Edit":
|
||||
return "edit";
|
||||
case "TodoWrite":
|
||||
return "todo";
|
||||
default:
|
||||
return `Running ${label}…`;
|
||||
return "other";
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool icon */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function ToolIcon({
|
||||
category,
|
||||
isStreaming,
|
||||
isError,
|
||||
}: {
|
||||
category: ToolCategory;
|
||||
isStreaming: boolean;
|
||||
isError: boolean;
|
||||
}) {
|
||||
if (isError) {
|
||||
return (
|
||||
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={20} />;
|
||||
}
|
||||
|
||||
const iconClass = "text-neutral-400";
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "web":
|
||||
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-read":
|
||||
return <FileIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-write":
|
||||
return <FileIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-delete":
|
||||
return <TrashIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-list":
|
||||
return <FilesIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "search":
|
||||
return (
|
||||
<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "edit":
|
||||
return (
|
||||
<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "todo":
|
||||
return (
|
||||
<ListChecksIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
default:
|
||||
return <GearIcon size={14} weight="regular" className={iconClass} />;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Accordion icon (larger, for the accordion header) */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function AccordionIcon({ category }: { category: ToolCategory }) {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return <TerminalIcon size={32} weight="light" />;
|
||||
case "web":
|
||||
return <GlobeIcon size={32} weight="light" />;
|
||||
case "file-read":
|
||||
case "file-write":
|
||||
return <FileIcon size={32} weight="light" />;
|
||||
case "file-delete":
|
||||
return <TrashIcon size={32} weight="light" />;
|
||||
case "file-list":
|
||||
return <FilesIcon size={32} weight="light" />;
|
||||
case "search":
|
||||
return <MagnifyingGlassIcon size={32} weight="light" />;
|
||||
case "edit":
|
||||
return <PencilSimpleIcon size={32} weight="light" />;
|
||||
case "todo":
|
||||
return <ListChecksIcon size={32} weight="light" />;
|
||||
default:
|
||||
return <GearIcon size={32} weight="light" />;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Input extraction */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getInputSummary(toolName: string, input: unknown): string | null {
|
||||
if (!input || typeof input !== "object") return null;
|
||||
const inp = input as Record<string, unknown>;
|
||||
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return typeof inp.command === "string" ? inp.command : null;
|
||||
case "web_fetch":
|
||||
case "WebFetch":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "WebSearch":
|
||||
return typeof inp.query === "string" ? inp.query : null;
|
||||
case "read_workspace_file":
|
||||
case "Read":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "write_workspace_file":
|
||||
case "Write":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "delete_workspace_file":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "Glob":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "Grep":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "Edit":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "TodoWrite": {
|
||||
// Extract the in-progress task name for the status line
|
||||
const todos = Array.isArray(inp.todos) ? inp.todos : [];
|
||||
const active = todos.find(
|
||||
(t: Record<string, unknown>) => t.status === "in_progress",
|
||||
);
|
||||
if (active && typeof active.activeForm === "string")
|
||||
return active.activeForm;
|
||||
if (active && typeof active.content === "string") return active.content;
|
||||
return null;
|
||||
}
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function truncate(text: string, maxLen: number): string {
|
||||
if (text.length <= maxLen) return text;
|
||||
return text.slice(0, maxLen).trimEnd() + "…";
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Animation text */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
|
||||
const toolName = extractToolName(part);
|
||||
const summary = getInputSummary(toolName, part.input);
|
||||
const shortSummary = summary ? truncate(summary, 60) : null;
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return shortSummary ? `Running: ${shortSummary}` : "Running command…";
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searching "${shortSummary}"`
|
||||
: "Searching the web…";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetching ${shortSummary}`
|
||||
: "Fetching web content…";
|
||||
case "file-read":
|
||||
return shortSummary ? `Reading ${shortSummary}` : "Reading file…";
|
||||
case "file-write":
|
||||
return shortSummary ? `Writing ${shortSummary}` : "Writing file…";
|
||||
case "file-delete":
|
||||
return shortSummary ? `Deleting ${shortSummary}` : "Deleting file…";
|
||||
case "file-list":
|
||||
return shortSummary ? `Listing ${shortSummary}` : "Listing files…";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searching for "${shortSummary}"`
|
||||
: "Searching…";
|
||||
case "edit":
|
||||
return shortSummary ? `Editing ${shortSummary}` : "Editing file…";
|
||||
case "todo":
|
||||
return shortSummary ? `${shortSummary}` : "Updating task list…";
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}…`;
|
||||
}
|
||||
}
|
||||
case "output-available": {
|
||||
switch (category) {
|
||||
case "bash": {
|
||||
const exitCode = getExitCode(part.output);
|
||||
if (exitCode !== null && exitCode !== 0) {
|
||||
return `Command exited with code ${exitCode}`;
|
||||
}
|
||||
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
|
||||
}
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searched "${shortSummary}"`
|
||||
: "Web search completed";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetched ${shortSummary}`
|
||||
: "Fetched web content";
|
||||
case "file-read":
|
||||
return shortSummary ? `Read ${shortSummary}` : "File read completed";
|
||||
case "file-write":
|
||||
return shortSummary ? `Wrote ${shortSummary}` : "File written";
|
||||
case "file-delete":
|
||||
return shortSummary ? `Deleted ${shortSummary}` : "File deleted";
|
||||
case "file-list":
|
||||
return "Listed files";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searched for "${shortSummary}"`
|
||||
: "Search completed";
|
||||
case "edit":
|
||||
return shortSummary ? `Edited ${shortSummary}` : "Edit completed";
|
||||
case "todo":
|
||||
return "Updated task list";
|
||||
default:
|
||||
return `${formatToolName(toolName)} completed`;
|
||||
}
|
||||
}
|
||||
case "output-error": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return "Command failed";
|
||||
case "web":
|
||||
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
|
||||
default:
|
||||
return `${formatToolName(toolName)} failed`;
|
||||
}
|
||||
}
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}…`;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Output parsing helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function parseOutput(output: unknown): Record<string, unknown> | null {
|
||||
if (!output) return null;
|
||||
if (typeof output === "object") return output as Record<string, unknown>;
|
||||
if (typeof output === "string") {
|
||||
const trimmed = output.trim();
|
||||
if (!trimmed) return null;
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
if (typeof parsed === "object" && parsed !== null) return parsed;
|
||||
} catch {
|
||||
// Return as a message wrapper for plain text output
|
||||
return { _raw: trimmed };
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract text from MCP-style content blocks.
|
||||
* SDK built-in tools (WebSearch, etc.) may return `{content: [{type:"text", text:"..."}]}`.
|
||||
*/
|
||||
function extractMcpText(output: Record<string, unknown>): string | null {
|
||||
if (Array.isArray(output.content)) {
|
||||
const texts = (output.content as Array<Record<string, unknown>>)
|
||||
.filter((b) => b.type === "text" && typeof b.text === "string")
|
||||
.map((b) => b.text as string);
|
||||
if (texts.length > 0) return texts.join("\n");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function getExitCode(output: unknown): number | null {
|
||||
const parsed = parseOutput(output);
|
||||
if (!parsed) return null;
|
||||
if (typeof parsed.exit_code === "number") return parsed.exit_code;
|
||||
return null;
|
||||
}
|
||||
|
||||
function getStringField(
|
||||
obj: Record<string, unknown>,
|
||||
...keys: string[]
|
||||
): string | null {
|
||||
for (const key of keys) {
|
||||
if (typeof obj[key] === "string" && obj[key].length > 0)
|
||||
return obj[key] as string;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Accordion content per tool category */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
interface AccordionData {
|
||||
title: string;
|
||||
description?: string;
|
||||
content: React.ReactNode;
|
||||
}
|
||||
|
||||
function getBashAccordionData(
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const inp = (input && typeof input === "object" ? input : {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const command = typeof inp.command === "string" ? inp.command : "Command";
|
||||
|
||||
const stdout = getStringField(output, "stdout");
|
||||
const stderr = getStringField(output, "stderr");
|
||||
const exitCode =
|
||||
typeof output.exit_code === "number" ? output.exit_code : null;
|
||||
const timedOut = output.timed_out === true;
|
||||
const message = getStringField(output, "message");
|
||||
|
||||
const title = timedOut
|
||||
? "Command timed out"
|
||||
: exitCode !== null && exitCode !== 0
|
||||
? `Command failed (exit ${exitCode})`
|
||||
: "Command output";
|
||||
|
||||
return {
|
||||
title,
|
||||
description: truncate(command, 80),
|
||||
content: (
|
||||
<div className="space-y-2">
|
||||
{stdout && (
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-slate-500">stdout</p>
|
||||
<ContentCodeBlock>{truncate(stdout, 2000)}</ContentCodeBlock>
|
||||
</div>
|
||||
)}
|
||||
{stderr && (
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-slate-500">stderr</p>
|
||||
<ContentCodeBlock>{truncate(stderr, 1000)}</ContentCodeBlock>
|
||||
</div>
|
||||
)}
|
||||
{!stdout && !stderr && message && (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
function getWebAccordionData(
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const inp = (input && typeof input === "object" ? input : {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const url =
|
||||
getStringField(inp as Record<string, unknown>, "url", "query") ??
|
||||
"Web content";
|
||||
|
||||
// Try direct string fields first, then MCP content blocks, then raw JSON
|
||||
let content = getStringField(output, "content", "text", "_raw");
|
||||
if (!content) content = extractMcpText(output);
|
||||
if (!content) {
|
||||
// Fallback: render the raw JSON so the accordion isn't empty
|
||||
try {
|
||||
const raw = JSON.stringify(output, null, 2);
|
||||
if (raw !== "{}") content = raw;
|
||||
} catch {
|
||||
/* empty */
|
||||
}
|
||||
}
|
||||
|
||||
const statusCode =
|
||||
typeof output.status_code === "number" ? output.status_code : null;
|
||||
const message = getStringField(output, "message");
|
||||
|
||||
return {
|
||||
title: statusCode ? `Response (${statusCode})` : "Search results",
|
||||
description: truncate(url, 80),
|
||||
content: content ? (
|
||||
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
|
||||
) : message ? (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
) : null,
|
||||
};
|
||||
}
|
||||
|
||||
function getFileAccordionData(
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const inp = (input && typeof input === "object" ? input : {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const filePath =
|
||||
getStringField(
|
||||
inp as Record<string, unknown>,
|
||||
"file_path",
|
||||
"path",
|
||||
"pattern",
|
||||
) ?? "File";
|
||||
const content = getStringField(output, "content", "text", "_raw");
|
||||
const message = getStringField(output, "message");
|
||||
// For Glob/list results, try to show file list
|
||||
const files = Array.isArray(output.files) ? (output.files as string[]) : null;
|
||||
|
||||
return {
|
||||
title: message ?? "File output",
|
||||
description: truncate(filePath, 80),
|
||||
content: (
|
||||
<div className="space-y-2">
|
||||
{content && (
|
||||
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
|
||||
)}
|
||||
{files && files.length > 0 && (
|
||||
<ContentCodeBlock>{files.join("\n")}</ContentCodeBlock>
|
||||
)}
|
||||
{!content && !files && message && (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
interface TodoItem {
|
||||
content: string;
|
||||
status: "pending" | "in_progress" | "completed";
|
||||
activeForm?: string;
|
||||
}
|
||||
|
||||
function getTodoAccordionData(input: unknown): AccordionData {
|
||||
const inp = (input && typeof input === "object" ? input : {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const todos: TodoItem[] = Array.isArray(inp.todos)
|
||||
? inp.todos.filter(
|
||||
(t: unknown): t is TodoItem =>
|
||||
typeof t === "object" &&
|
||||
t !== null &&
|
||||
typeof (t as TodoItem).content === "string",
|
||||
)
|
||||
: [];
|
||||
|
||||
const completed = todos.filter((t) => t.status === "completed").length;
|
||||
const total = todos.length;
|
||||
|
||||
return {
|
||||
title: "Task list",
|
||||
description: `${completed}/${total} completed`,
|
||||
content: (
|
||||
<div className="space-y-1 py-1">
|
||||
{todos.map((todo, i) => (
|
||||
<div key={i} className="flex items-start gap-2 text-xs">
|
||||
<span className="mt-0.5 flex-shrink-0">
|
||||
{todo.status === "completed" ? (
|
||||
<CheckCircleIcon
|
||||
size={14}
|
||||
weight="fill"
|
||||
className="text-green-500"
|
||||
/>
|
||||
) : todo.status === "in_progress" ? (
|
||||
<CircleDashedIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="text-blue-500"
|
||||
/>
|
||||
) : (
|
||||
<CircleIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className="text-neutral-400"
|
||||
/>
|
||||
)}
|
||||
</span>
|
||||
<span
|
||||
className={
|
||||
todo.status === "completed"
|
||||
? "text-muted-foreground line-through"
|
||||
: todo.status === "in_progress"
|
||||
? "font-medium text-foreground"
|
||||
: "text-muted-foreground"
|
||||
}
|
||||
>
|
||||
{todo.content}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
function getDefaultAccordionData(
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const message = getStringField(output, "message");
|
||||
const raw = output._raw;
|
||||
const mcpText = extractMcpText(output);
|
||||
|
||||
let displayContent: string;
|
||||
if (typeof raw === "string") {
|
||||
displayContent = raw;
|
||||
} else if (mcpText) {
|
||||
displayContent = mcpText;
|
||||
} else if (message) {
|
||||
displayContent = message;
|
||||
} else {
|
||||
try {
|
||||
displayContent = JSON.stringify(output, null, 2);
|
||||
} catch {
|
||||
displayContent = String(output);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
title: "Output",
|
||||
description: message ?? undefined,
|
||||
content: (
|
||||
<ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
function getAccordionData(
|
||||
category: ToolCategory,
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return getBashAccordionData(input, output);
|
||||
case "web":
|
||||
return getWebAccordionData(input, output);
|
||||
case "file-read":
|
||||
case "file-write":
|
||||
case "file-delete":
|
||||
case "file-list":
|
||||
case "search":
|
||||
case "edit":
|
||||
return getFileAccordionData(input, output);
|
||||
case "todo":
|
||||
return getTodoAccordionData(input);
|
||||
default:
|
||||
return getDefaultAccordionData(output);
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Component */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export function GenericTool({ part }: Props) {
|
||||
const toolName = extractToolName(part);
|
||||
const category = getToolCategory(toolName);
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
const text = getAnimationText(part, category);
|
||||
|
||||
const output = parseOutput(part.output);
|
||||
const hasOutput =
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
Object.keys(output).length > 0;
|
||||
const hasError = isError && !!output;
|
||||
|
||||
// TodoWrite: always show accordion from input (the todo list lives in input)
|
||||
const hasTodoInput =
|
||||
category === "todo" &&
|
||||
part.input &&
|
||||
typeof part.input === "object" &&
|
||||
Array.isArray((part.input as Record<string, unknown>).todos);
|
||||
|
||||
const showAccordion = hasOutput || hasError || hasTodoInput;
|
||||
const accordionData = showAccordion
|
||||
? getAccordionData(category, part.input, output ?? {})
|
||||
: null;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<GearIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className={
|
||||
isError
|
||||
? "text-red-500"
|
||||
: isStreaming
|
||||
? "animate-spin text-neutral-500"
|
||||
: "text-neutral-400"
|
||||
}
|
||||
<ToolIcon
|
||||
category={category}
|
||||
isStreaming={isStreaming}
|
||||
isError={isError}
|
||||
/>
|
||||
<MorphingTextAnimation
|
||||
text={getAnimationText(part)}
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showAccordion && accordionData ? (
|
||||
<ToolAccordion
|
||||
icon={<AccordionIcon category={category} />}
|
||||
title={accordionData.title}
|
||||
description={accordionData.description}
|
||||
titleClassName={isError ? "text-red-500" : undefined}
|
||||
>
|
||||
{accordionData.content}
|
||||
</ToolAccordion>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -80,6 +80,12 @@ export function useCopilotPage() {
|
||||
},
|
||||
};
|
||||
},
|
||||
// Resume: GET goes to the same URL as POST (backend uses
|
||||
// method to distinguish). Override the default formula which
|
||||
// would append /{chatId}/stream to the existing path.
|
||||
prepareReconnectToStreamRequest: () => ({
|
||||
api: `/api/chat/sessions/${sessionId}/stream`,
|
||||
}),
|
||||
})
|
||||
: null,
|
||||
[sessionId],
|
||||
@@ -88,6 +94,7 @@ export function useCopilotPage() {
|
||||
const { messages, sendMessage, stop, status, error, setMessages } = useChat({
|
||||
id: sessionId ?? undefined,
|
||||
transport: transport ?? undefined,
|
||||
resume: true,
|
||||
});
|
||||
|
||||
// Abort the stream if the backend doesn't start sending data within 12s.
|
||||
|
||||
Reference in New Issue
Block a user