Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
(synced 2026-02-19 02:54:28 -05:00)

Compare commits: copilot/sd...dependabot (1 commit)

Commit: d2e545ee48
@@ -4,6 +4,7 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""

import asyncio
import logging
import os
import threading

@@ -163,23 +164,21 @@ class CoPilotExecutor(AppProcess):
            self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
        )

        # Clean up worker threads (closes per-loop workspace storage sessions)
        # Shutdown executor
        if self._executor:
            from .processor import cleanup_worker

            logger.info(f"[cleanup {pid}] Cleaning up workers...")
            futures = []
            for _ in range(self._executor._max_workers):
                futures.append(self._executor.submit(cleanup_worker))
            for f in futures:
                try:
                    f.result(timeout=10)
                except Exception as e:
                    logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")

            logger.info(f"[cleanup {pid}] Shutting down executor...")
            self._executor.shutdown(wait=False)
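
# A minimal sketch of the submit-one-task-per-worker idiom above: each
# ThreadPoolExecutor thread owns thread-local state, so cleanup has to run
# *on* every worker. A Barrier stops one fast thread from consuming two
# cleanup tasks. Stdlib-only; all names are illustrative, and it assumes the
# pool was created with max_workers == N_WORKERS.
import threading
from concurrent.futures import ThreadPoolExecutor

_tls = threading.local()
N_WORKERS = 4

def _cleanup(barrier: threading.Barrier) -> None:
    _tls.resource = None  # tear down this thread's loop-bound resource here
    barrier.wait()        # park until every worker has cleaned up

def shutdown(executor: ThreadPoolExecutor) -> None:
    barrier = threading.Barrier(N_WORKERS)
    futures = [executor.submit(_cleanup, barrier) for _ in range(N_WORKERS)]
    for f in futures:
        f.result(timeout=10)
    executor.shutdown(wait=False)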

        # Close async resources (workspace storage aiohttp session, etc.)
        try:
            from backend.util.workspace_storage import shutdown_workspace_storage

            loop = asyncio.new_event_loop()
            loop.run_until_complete(shutdown_workspace_storage())
            loop.close()
        except Exception as e:
            logger.warning(f"[cleanup {pid}] Error closing workspace storage: {e}")

        # Release any remaining locks
        for task_id, lock in list(self._task_locks.items()):
            try:

@@ -60,18 +60,6 @@ def init_worker():
    _tls.processor.on_executor_start()


def cleanup_worker():
    """Clean up the processor for the current worker thread.

    Should be called before the worker thread's event loop is destroyed so
    that event-loop-bound resources (e.g. ``aiohttp.ClientSession``) are
    closed on the correct loop.
    """
    processor: CoPilotProcessor | None = getattr(_tls, "processor", None)
    if processor is not None:
        processor.cleanup()


# ============ Processor Class ============ #


@@ -110,28 +98,6 @@ class CoPilotProcessor:

        logger.info(f"[CoPilotExecutor] Worker {self.tid} started")

    def cleanup(self):
        """Clean up event-loop-bound resources before the loop is destroyed.

        Shuts down the workspace storage instance that belongs to this
        worker's event loop, ensuring ``aiohttp.ClientSession.close()``
        runs on the same loop that created the session.
        """
        from backend.util.workspace_storage import shutdown_workspace_storage

        try:
            future = asyncio.run_coroutine_threadsafe(
                shutdown_workspace_storage(), self.execution_loop
            )
            future.result(timeout=5)
        except Exception as e:
            logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")

        # Stop the event loop
        self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
        self.execution_thread.join(timeout=5)
        logger.info(f"[CoPilotExecutor] Worker {self.tid} cleaned up")

    @error_logged(swallow=False)
    def execute(
        self,

@@ -53,7 +53,6 @@ class SDKResponseAdapter:
        self.has_started_text = False
        self.has_ended_text = False
        self.current_tool_calls: dict[str, dict[str, str]] = {}
        self.resolved_tool_calls: set[str] = set()
        self.task_id: str | None = None
        self.step_open = False

@@ -75,10 +74,6 @@ class SDKResponseAdapter:
            self.step_open = True

        elif isinstance(sdk_message, AssistantMessage):
            # Flush any SDK built-in tool calls that didn't get a UserMessage
            # result (e.g. WebSearch, Read handled internally by the CLI).
            self._flush_unresolved_tool_calls(responses)

            # After tool results, the SDK sends a new AssistantMessage for the
            # next LLM turn. Open a new step if the previous one was closed.
            if not self.step_open:

@@ -116,8 +111,6 @@ class SDKResponseAdapter:
            # UserMessage carries tool results back from tool execution.
            content = sdk_message.content
            blocks = content if isinstance(content, list) else []
            resolved_in_blocks: set[str] = set()

            for block in blocks:
                if isinstance(block, ToolResultBlock) and block.tool_use_id:
                    tool_info = self.current_tool_calls.get(block.tool_use_id, {})

@@ -139,37 +132,6 @@ class SDKResponseAdapter:
                            success=not (block.is_error or False),
                        )
                    )
                    resolved_in_blocks.add(block.tool_use_id)

            # Handle SDK built-in tool results carried via parent_tool_use_id
            # instead of (or in addition to) ToolResultBlock content.
            parent_id = sdk_message.parent_tool_use_id
            if parent_id and parent_id not in resolved_in_blocks:
                tool_info = self.current_tool_calls.get(parent_id, {})
                tool_name = tool_info.get("name", "unknown")

                # Try stashed output first (from PostToolUse hook),
                # then tool_use_result dict, then string content.
                output = pop_pending_tool_output(tool_name)
                if not output:
                    tur = sdk_message.tool_use_result
                    if tur is not None:
                        output = _extract_tool_use_result(tur)
                if not output and isinstance(content, str) and content.strip():
                    output = content.strip()

                if output:
                    responses.append(
                        StreamToolOutputAvailable(
                            toolCallId=parent_id,
                            toolName=tool_name,
                            output=output,
                            success=True,
                        )
                    )
                    resolved_in_blocks.add(parent_id)

            self.resolved_tool_calls.update(resolved_in_blocks)

            # Close the current step after tool results — the next
            # AssistantMessage will open a new step for the continuation.

@@ -178,7 +140,6 @@ class SDKResponseAdapter:
                self.step_open = False

        elif isinstance(sdk_message, ResultMessage):
            self._flush_unresolved_tool_calls(responses)
            self._end_text_if_open(responses)
            # Close the step before finishing.
            if self.step_open:

@@ -188,7 +149,7 @@ class SDKResponseAdapter:
            if sdk_message.subtype == "success":
                responses.append(StreamFinish())
            elif sdk_message.subtype in ("error", "error_during_execution"):
                error_msg = sdk_message.result or "Unknown error"
                error_msg = getattr(sdk_message, "result", None) or "Unknown error"
                responses.append(
                    StreamError(errorText=str(error_msg), code="sdk_error")
                )

@@ -219,59 +180,6 @@ class SDKResponseAdapter:
            responses.append(StreamTextEnd(id=self.text_block_id))
            self.has_ended_text = True

    def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
        """Emit outputs for tool calls that didn't receive a UserMessage result.

        SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
        internally without surfacing a separate ``UserMessage`` with
        ``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
        output, which we pop and emit here before the next ``AssistantMessage``
        starts.
        """
        flushed = False
        for tool_id, tool_info in self.current_tool_calls.items():
            if tool_id in self.resolved_tool_calls:
                continue
            tool_name = tool_info.get("name", "unknown")
            output = pop_pending_tool_output(tool_name)
            if output is not None:
                responses.append(
                    StreamToolOutputAvailable(
                        toolCallId=tool_id,
                        toolName=tool_name,
                        output=output,
                        success=True,
                    )
                )
                self.resolved_tool_calls.add(tool_id)
                flushed = True
                logger.debug(
                    f"Flushed pending output for built-in tool {tool_name} "
                    f"(call {tool_id})"
                )
            else:
                # No output available — emit an empty output so the frontend
                # transitions the tool from input-available to output-available
                # (stops the spinner).
                responses.append(
                    StreamToolOutputAvailable(
                        toolCallId=tool_id,
                        toolName=tool_name,
                        output="",
                        success=True,
                    )
                )
                self.resolved_tool_calls.add(tool_id)
                flushed = True
                logger.debug(
                    f"Flushed empty output for unresolved tool {tool_name} "
                    f"(call {tool_id})"
                )

        if flushed and self.step_open:
            responses.append(StreamFinishStep())
            self.step_open = False
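
# The flush logic above maintains one invariant: every tool call emits
# exactly one output event, real or empty, so the frontend never shows a
# stuck spinner. A stripped-down model of that bookkeeping, using
# hypothetical plain tuples instead of the Stream* response classes:
def flush_unresolved(
    calls: dict[str, str], resolved: set[str], pending: dict[str, str]
) -> list[tuple[str, str]]:
    events = []
    for call_id, tool_name in calls.items():
        if call_id in resolved:
            continue
        # Stashed output if the PostToolUse hook captured one, else "".
        events.append((call_id, pending.pop(tool_name, "")))
        resolved.add(call_id)
    return events

assert flush_unresolved({"c1": "WebSearch"}, set(), {"WebSearch": "hits"}) == [
    ("c1", "hits")
]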


def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
    """Extract a string output from a ToolResultBlock's content field."""

@@ -291,30 +199,3 @@ def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)


def _extract_tool_use_result(result: object) -> str:
    """Extract a string from a UserMessage's ``tool_use_result`` dict.

    SDK built-in tools may store their result in ``tool_use_result``
    instead of (or in addition to) ``ToolResultBlock`` content blocks.
    """
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        # Try common result keys
        for key in ("content", "text", "output", "stdout", "result"):
            val = result.get(key)
            if isinstance(val, str) and val:
                return val
        # Fall back to JSON serialization of the whole dict
        try:
            return json.dumps(result)
        except (TypeError, ValueError):
            return str(result)
    if result is None:
        return ""
    try:
        return json.dumps(result)
    except (TypeError, ValueError):
        return str(result)

@@ -16,7 +16,6 @@ from .tool_adapter import (
    DANGEROUS_PATTERNS,
    MCP_TOOL_PREFIX,
    WORKSPACE_SCOPED_TOOLS,
    stash_pending_tool_output,
)

logger = logging.getLogger(__name__)

@@ -225,25 +224,10 @@ def create_security_hooks(
        tool_use_id: str | None,
        context: HookContext,
    ) -> SyncHookJSONOutput:
        """Log successful tool executions and stash SDK built-in tool outputs.

        MCP tools stash their output in ``_execute_tool_sync`` before the
        SDK can truncate it. SDK built-in tools (WebSearch, Read, etc.)
        are executed by the CLI internally — this hook captures their
        output so the response adapter can forward it to the frontend.
        """
        """Log successful tool executions for observability."""
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))
        logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")

        # Stash output for SDK built-in tools so the response adapter can
        # emit StreamToolOutputAvailable even when the CLI doesn't surface
        # a separate UserMessage with ToolResultBlock content.
        if not tool_name.startswith(MCP_TOOL_PREFIX):
            tool_response = input_data.get("tool_response")
            if tool_response is not None:
                stash_pending_tool_output(tool_name, tool_response)

        return cast(SyncHookJSONOutput, {})

    async def post_tool_failure_hook(

@@ -47,7 +47,6 @@ from .tool_adapter import (
    set_execution_context,
)
from .transcript import (
    cleanup_cli_project_dir,
    download_transcript,
    read_transcript_file,
    upload_transcript,

@@ -87,12 +86,9 @@ _SDK_TOOL_SUPPLEMENT = """
  for shell commands — it runs in a network-isolated sandbox.
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
  same working directory. Files created by one are readable by the other.
- **IMPORTANT — File persistence**: Your working directory is **ephemeral** —
  files are lost between turns. When you create or modify important files
  (code, configs, outputs), you MUST save them using `write_workspace_file`
  so they persist. Use `read_workspace_file` and `list_workspace_files` to
  access files saved in previous turns. If a "Files from previous turns"
  section is present above, those files are available via `read_workspace_file`.
  These files are **ephemeral** — they exist only for the current session.
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
  for files that should persist across sessions (stored in cloud storage).
- Long-running tools (create_agent, edit_agent, etc.) are handled
  asynchronously. You will receive an immediate response; the actual result
  is delivered to the user via a background stream.

@@ -272,28 +268,48 @@ def _make_sdk_cwd(session_id: str) -> str:


def _cleanup_sdk_tool_results(cwd: str) -> None:
    """Remove SDK session artifacts for a specific working directory.
    """Remove SDK tool-result files for a specific session working directory.

    Cleans up:
    - ``~/.claude/projects/<encoded-cwd>/`` — CLI session transcripts and
      tool-result files. Each SDK turn uses a unique cwd, so this directory
      is safe to remove entirely.
    - ``/tmp/copilot-<session>/`` — the ephemeral working directory.
    The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
    We clean only the specific cwd's results to avoid race conditions between
    concurrent sessions.

    Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
    the session_id.
    Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
    """
    import shutil

    # Validate cwd is under the expected prefix
    normalized = os.path.normpath(cwd)
    if not normalized.startswith(_SDK_CWD_PREFIX):
        logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
        return

    # Clean the CLI's project directory (transcripts + tool-results).
    cleanup_cli_project_dir(cwd)
    # SDK encodes the cwd path by replacing '/' with '-'
    encoded_cwd = normalized.replace("/", "-")

    # Clean up the temp cwd directory itself.
    # Construct the project directory path (known-safe home expansion)
    claude_projects = os.path.expanduser("~/.claude/projects")
    project_dir = os.path.join(claude_projects, encoded_cwd)

    # Security check 3: Validate project_dir is under ~/.claude/projects
    project_dir = os.path.normpath(project_dir)
    if not project_dir.startswith(claude_projects):
        logger.warning(
            f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
        )
        return

    results_dir = os.path.join(project_dir, "tool-results")
    if os.path.isdir(results_dir):
        for filename in os.listdir(results_dir):
            file_path = os.path.join(results_dir, filename)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError:
                pass

    # Also clean up the temp cwd directory itself
    try:
        shutil.rmtree(normalized, ignore_errors=True)
    except OSError:
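
# Both prefix checks above follow the usual containment pattern: normalize
# first, then require the result to stay under a trusted base. A stdlib-only
# sketch (realpath also collapses ".." before the comparison); the example
# paths are illustrative.
import os

def is_contained(path: str, base: str) -> bool:
    base = os.path.realpath(base)
    target = os.path.realpath(path)
    return target == base or target.startswith(base + os.sep)

assert is_contained("/tmp/copilot-abc/work", "/tmp/copilot-abc")
assert not is_contained("/tmp/copilot-abc/../../etc", "/tmp/copilot-abc")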

@@ -503,7 +519,6 @@ async def stream_chat_completion_sdk(
    def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
        captured_transcript.path = transcript_path
        captured_transcript.sdk_session_id = sdk_session_id
        logger.debug(f"[SDK] Stop hook: path={transcript_path!r}")

    security_hooks = create_security_hooks(
        user_id,

@@ -515,20 +530,18 @@ async def stream_chat_completion_sdk(
    # --- Resume strategy: download transcript from bucket ---
    resume_file: str | None = None
    use_resume = False
    transcript_msg_count = 0  # watermark: session.messages length at upload

    if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
        dl = await download_transcript(user_id, session_id)
        if dl and validate_transcript(dl.content):
        transcript_content = await download_transcript(user_id, session_id)
        if transcript_content and validate_transcript(transcript_content):
            resume_file = write_transcript_to_tempfile(
                dl.content, session_id, sdk_cwd
                transcript_content, session_id, sdk_cwd
            )
            if resume_file:
                use_resume = True
                transcript_msg_count = dl.message_count
                logger.debug(
                    f"[SDK] Using --resume ({len(dl.content)}B, "
                    f"msg_count={transcript_msg_count})"
                logger.info(
                    f"[SDK] Using --resume with transcript "
                    f"({len(transcript_content)} bytes)"
                )

    sdk_options_kwargs: dict[str, Any] = {

@@ -569,35 +582,11 @@ async def stream_chat_completion_sdk(
            # Build query: with --resume the CLI already has full
            # context, so we only send the new message. Without
            # resume, compress history into a context prefix.
            #
            # Hybrid mode: if the transcript is stale (upload missed
            # some turns), compress only the gap and prepend it so
            # the agent has transcript context + missed turns.
            query_message = current_message
            current_msg_count = len(session.messages)

            if use_resume and transcript_msg_count >= 0:
                # Transcript covers messages[0..M-1]. Current session
                # has N messages (last one is the new user msg).
                # Gap = messages[M .. N-2] (everything between upload
                # and the current turn).
                if transcript_msg_count < current_msg_count - 1:
                    gap = session.messages[transcript_msg_count:-1]
                    gap_context = _format_conversation_context(gap)
                    if gap_context:
                        logger.info(
                            f"[SDK] Transcript stale: covers {transcript_msg_count} "
                            f"of {current_msg_count} messages, compressing "
                            f"{len(gap)} missed messages"
                        )
                        query_message = (
                            f"{gap_context}\n\n"
                            f"Now, the user says:\n{current_message}"
                        )
            elif not use_resume and current_msg_count > 1:
            if not use_resume and len(session.messages) > 1:
                logger.warning(
                    f"[SDK] Using compression fallback for session "
                    f"{session_id} ({current_msg_count} messages) — "
                    f"{session_id} ({len(session.messages)} messages) — "
                    f"no transcript available for --resume"
                )
                compressed = await _compress_conversation_history(session)
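
# Sketch of the staleness window described above: the transcript covers
# messages[0..M-1], the live session holds N messages whose last entry is
# the new user message, so the missed turns are messages[M..N-2]. The list
# of strings is an illustrative stand-in for the session object.
def missed_turns(messages: list[str], transcript_msg_count: int) -> list[str]:
    if transcript_msg_count < len(messages) - 1:
        return messages[transcript_msg_count:-1]
    return []

history = ["u1", "a1", "u2", "a2", "u3"]         # "u3" is the new user message
assert missed_turns(history, 2) == ["u2", "a2"]  # transcript missed one turn
assert missed_turns(history, 4) == []            # transcript is up to date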

@@ -608,10 +597,10 @@ async def stream_chat_completion_sdk(
                    f"Now, the user says:\n{current_message}"
                )

            logger.debug(
                f"[SDK] Sending query ({len(session.messages)} msgs, "
                f"resume={use_resume})"
            logger.info(
                f"[SDK] Sending query ({len(session.messages)} msgs in session)"
            )
            logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
            await client.query(query_message, session_id=session_id)

            assistant_response = ChatMessage(role="assistant", content="")

@@ -692,33 +681,25 @@ async def stream_chat_completion_sdk(
        ) and not has_appended_assistant:
            session.messages.append(assistant_response)

        # --- Upload transcript for next-turn --resume ---
        # After async with the SDK task group has exited, so the Stop
        # hook has already fired and the CLI has been SIGTERMed. The
        # CLI uses appendFileSync, so all writes are safely on disk.
        if config.claude_agent_use_resume and user_id:
            # With --resume the CLI appends to the resume file (most
            # complete). Otherwise use the Stop hook path.
            if use_resume and resume_file:
                raw_transcript = read_transcript_file(resume_file)
            elif captured_transcript.path:
            # --- Capture transcript while CLI is still alive ---
            # Must happen INSIDE async with: close() sends SIGTERM
            # which kills the CLI before it can flush the JSONL.
            if (
                config.claude_agent_use_resume
                and user_id
                and captured_transcript.available
            ):
                # Give CLI time to flush JSONL writes before we read
                await asyncio.sleep(0.5)
                raw_transcript = read_transcript_file(captured_transcript.path)
            else:
                raw_transcript = None

            if raw_transcript:
                # Shield the upload from generator cancellation so a
                # client disconnect / page refresh doesn't lose the
                # transcript. The upload must finish even if the SSE
                # connection is torn down.
                await asyncio.shield(
                    _try_upload_transcript(
                        user_id,
                        session_id,
                        raw_transcript,
                        message_count=len(session.messages),
                if raw_transcript:
                    task = asyncio.create_task(
                        _upload_transcript_bg(user_id, session_id, raw_transcript)
                    )
                )
                    _background_tasks.add(task)
                    task.add_done_callback(_background_tasks.discard)
                else:
                    logger.debug("[SDK] Stop hook fired but transcript not usable")

    except ImportError:
        raise RuntimeError(
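
# Sketch of the background-task bookkeeping introduced above: asyncio holds
# only weak references to tasks, so a bare create_task() result can be
# garbage-collected mid-flight. Keeping it in a set and discarding it from a
# done-callback is the standard fix.
import asyncio

_background_tasks: set[asyncio.Task] = set()

def spawn(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)                        # strong reference
    task.add_done_callback(_background_tasks.discard)  # released when done
    return task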

@@ -727,7 +708,7 @@ async def stream_chat_completion_sdk(
        "to use the OpenAI-compatible fallback."
    )

    await asyncio.shield(upsert_chat_session(session))
    await upsert_chat_session(session)
    logger.debug(
        f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
    )
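
# Sketch of the asyncio.shield idiom that appears above: shield lets the
# awaiting coroutine be cancelled (its await raises CancelledError) while
# the inner save keeps running to completion. Stdlib-only stand-ins.
import asyncio

async def save_session() -> None:
    await asyncio.sleep(0.05)  # stands in for upsert_chat_session(session)

async def handler() -> None:
    await asyncio.shield(save_session())

async def main() -> None:
    task = asyncio.create_task(handler())
    await asyncio.sleep(0.01)
    task.cancel()              # cancels handler(), not the shielded save
    try:
        await task
    except asyncio.CancelledError:
        pass
    await asyncio.sleep(0.1)   # the shielded save still finishes here

asyncio.run(main())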

@@ -737,7 +718,7 @@ async def stream_chat_completion_sdk(
    except Exception as e:
        logger.error(f"[SDK] Error: {e}", exc_info=True)
        try:
            await asyncio.shield(upsert_chat_session(session))
            await upsert_chat_session(session)
        except Exception as save_err:
            logger.error(f"[SDK] Failed to save session on error: {save_err}")
        yield StreamError(

@@ -750,31 +731,14 @@ async def stream_chat_completion_sdk(
        _cleanup_sdk_tool_results(sdk_cwd)


async def _try_upload_transcript(
    user_id: str,
    session_id: str,
    raw_content: str,
    message_count: int = 0,
) -> bool:
    """Strip progress entries and upload transcript (with timeout).

    Returns True if the upload completed without error.
    """
async def _upload_transcript_bg(
    user_id: str, session_id: str, raw_content: str
) -> None:
    """Background task to strip progress entries and upload transcript."""
    try:
        async with asyncio.timeout(30):
            await upload_transcript(
                user_id, session_id, raw_content, message_count=message_count
            )
        return True
    except asyncio.TimeoutError:
        logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
        return False
        await upload_transcript(user_id, session_id, raw_content)
    except Exception as e:
        logger.error(
            f"[SDK] Failed to upload transcript for {session_id}: {e}",
            exc_info=True,
        )
        return False
        logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")


async def _update_title_async(

@@ -41,7 +41,7 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
    "pending_tool_outputs", default=None  # type: ignore[arg-type]
)

@@ -88,52 +88,19 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:


def pop_pending_tool_output(tool_name: str) -> str | None:
    """Pop and return the oldest stashed output for *tool_name*.
    """Pop and return the stashed full output for *tool_name*.

    The SDK CLI may truncate large tool results (writing them to disk and
    replacing the content with a file reference). This stash keeps the
    original MCP output so the response adapter can forward it to the
    frontend for proper widget rendering.

    Uses a FIFO queue per tool name so duplicate calls to the same tool
    in one turn each get their own output.

    Returns ``None`` if nothing was stashed for *tool_name*.
    """
    pending = _pending_tool_outputs.get(None)
    if pending is None:
        return None
    queue = pending.get(tool_name)
    if not queue:
        pending.pop(tool_name, None)
        return None
    value = queue.pop(0)
    if not queue:
        del pending[tool_name]
    return value


def stash_pending_tool_output(tool_name: str, output: Any) -> None:
    """Stash tool output for later retrieval by the response adapter.

    Used by the PostToolUse hook to capture SDK built-in tool outputs
    (WebSearch, Read, etc.) that aren't available through the MCP stash
    mechanism in ``_execute_tool_sync``.

    Appends to a FIFO queue per tool name so multiple calls to the same
    tool in one turn are all preserved.
    """
    pending = _pending_tool_outputs.get(None)
    if pending is None:
        return
    if isinstance(output, str):
        text = output
    else:
        try:
            text = json.dumps(output)
        except (TypeError, ValueError):
            text = str(output)
    pending.setdefault(tool_name, []).append(text)
    return pending.pop(tool_name, None)
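
# Sketch of the FIFO stash shown above: a ContextVar mapping tool name to a
# list of outputs lets two calls to the same tool in one turn each keep
# their own result, popped oldest-first. Stdlib-only model with
# illustrative names.
from contextvars import ContextVar

_stash: ContextVar[dict[str, list[str]] | None] = ContextVar("stash", default=None)

def stash(tool: str, output: str) -> None:
    pending = _stash.get()
    if pending is not None:
        pending.setdefault(tool, []).append(output)

def pop(tool: str) -> str | None:
    pending = _stash.get()
    if not pending or not pending.get(tool):
        return None
    queue = pending[tool]
    value = queue.pop(0)
    if not queue:
        del pending[tool]
    return value

_stash.set({})
stash("WebSearch", "first")
stash("WebSearch", "second")
assert pop("WebSearch") == "first" and pop("WebSearch") == "second"
assert pop("WebSearch") is None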


async def _execute_tool_sync(

@@ -158,63 +125,14 @@ async def _execute_tool_sync(
    # Stash the full output before the SDK potentially truncates it.
    pending = _pending_tool_outputs.get(None)
    if pending is not None:
        pending.setdefault(base_tool.name, []).append(text)

    content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]

    # If the tool result contains inline image data, add an MCP image block
    # so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
    image_block = _extract_image_block(text)
    if image_block:
        content_blocks.append(image_block)
        pending[base_tool.name] = text

    return {
        "content": content_blocks,
        "content": [{"type": "text", "text": text}],
        "isError": not result.success,
    }


# MIME types that Claude can process as image content blocks.
_SUPPORTED_IMAGE_TYPES = frozenset(
    {"image/png", "image/jpeg", "image/gif", "image/webp"}
)


def _extract_image_block(text: str) -> dict[str, str] | None:
    """Extract an MCP image content block from a tool result JSON string.

    Detects workspace file responses with ``content_base64`` and an image
    MIME type, returning an MCP-format image block that allows Claude to
    "see" the image. Returns ``None`` if the result is not an inline image.
    """
    try:
        data = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        return None

    if not isinstance(data, dict):
        return None

    mime_type = data.get("mime_type", "")
    base64_content = data.get("content_base64", "")

    # Only inline small images — large ones would exceed Claude's limits.
    # 32 KB raw ≈ ~43 KB base64.
    _MAX_IMAGE_BASE64_BYTES = 43_000
    if (
        mime_type in _SUPPORTED_IMAGE_TYPES
        and base64_content
        and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
    ):
        return {
            "type": "image",
            "data": base64_content,
            "mimeType": mime_type,
        }

    return None
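
# The 43 KB budget above is just the base64 expansion of roughly 32 KB of
# raw bytes: base64 emits 4 output characters for every 3 input bytes.
import base64

raw = bytes(32 * 1024)
assert len(base64.b64encode(raw)) == 43_692  # ceil(32768 / 3) * 4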


def _mcp_error(message: str) -> dict[str, Any]:
    return {
        "content": [

@@ -393,29 +311,14 @@ def create_copilot_mcp_server():
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
_SDK_BUILTIN_TOOLS = [
    "Read",
    "Write",
    "Edit",
    "Glob",
    "Grep",
    "Task",
    "WebSearch",
    "TodoWrite",
]
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"]

# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
# network isolation (unshare --net) instead.
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
SDK_DISALLOWED_TOOLS = [
    "Bash",
    "WebFetch",
    "AskUserQuestion",
]
SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]

# Tools that are blocked entirely in security hooks (defence-in-depth).
# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.

@@ -14,8 +14,6 @@ import json
import logging
import os
import re
import time
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@@ -33,16 +31,6 @@ STRIPPABLE_TYPES = frozenset(
    {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)


@dataclass
class TranscriptDownload:
    """Result of downloading a transcript with its metadata."""

    content: str
    message_count: int = 0  # session.messages length when uploaded
    uploaded_at: float = 0.0  # epoch timestamp of upload


# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"

@@ -131,19 +119,23 @@ def read_transcript_file(transcript_path: str) -> str | None:
        content = f.read()

    if not content.strip():
        logger.debug(f"[Transcript] Empty file: {transcript_path}")
        return None

    lines = content.strip().split("\n")
    if len(lines) < 3:
        # Raw files with ≤2 lines are metadata-only
        # (queue-operation + file-history-snapshot, no conversation).
        logger.debug(
            f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
        )
        return None

    # Quick structural validation — parse first and last lines.
    json.loads(lines[0])
    json.loads(lines[-1])

    logger.debug(
    logger.info(
        f"[Transcript] Read {len(lines)} lines, "
        f"{len(content)} bytes from {transcript_path}"
    )

@@ -168,41 +160,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")


def _encode_cwd_for_cli(cwd: str) -> str:
    """Encode a working directory path the same way the Claude CLI does.

    The CLI replaces all non-alphanumeric characters with ``-``.
    """
    return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))


def cleanup_cli_project_dir(sdk_cwd: str) -> None:
    """Remove the CLI's project directory for a specific working directory.

    The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
    Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
    safe to remove entirely after the transcript has been uploaded.
    """
    import shutil

    cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
    config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
    projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
    project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))

    if not project_dir.startswith(projects_base + os.sep):
        logger.warning(
            f"[Transcript] Cleanup path escaped projects base: {project_dir}"
        )
        return

    if os.path.isdir(project_dir):
        shutil.rmtree(project_dir, ignore_errors=True)
        logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
    else:
        logger.debug(f"[Transcript] Project dir not found: {project_dir}")


def write_transcript_to_tempfile(
    transcript_content: str,
    session_id: str,

@@ -234,7 +191,7 @@ def write_transcript_to_tempfile(
        with open(jsonl_path, "w") as f:
            f.write(transcript_content)

        logger.debug(f"[Transcript] Wrote resume file: {jsonl_path}")
        logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
        return jsonl_path

    except OSError as e:

@@ -291,15 +248,6 @@ def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
    )


def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
    """Return (workspace_id, file_id, filename) for a session's transcript metadata."""
    return (
        TRANSCRIPT_STORAGE_PREFIX,
        _sanitize_id(user_id),
        f"{_sanitize_id(session_id)}.meta.json",
    )


def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
    """Build the full storage path string that ``retrieve()`` expects.

@@ -320,30 +268,21 @@ def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
    return f"local://{wid}/{fid}/{fname}"


async def upload_transcript(
    user_id: str,
    session_id: str,
    content: str,
    message_count: int = 0,
) -> None:
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
    """Strip progress entries and upload transcript to bucket storage.

    Safety: only overwrites when the new (stripped) transcript is larger than
    what is already stored. Since JSONL is append-only, the latest transcript
    is always the longest. This prevents a slow/stale background task from
    clobbering a newer upload from a concurrent turn.

    Args:
        message_count: ``len(session.messages)`` at upload time — used by
            the next turn to detect staleness and compress only the gap.
    """
    from backend.util.workspace_storage import get_workspace_storage

    stripped = strip_progress_entries(content)
    if not validate_transcript(stripped):
        logger.warning(
            f"[Transcript] Skipping upload — stripped content not valid "
            f"for session {session_id}"
            f"[Transcript] Skipping upload — stripped content is not a valid "
            f"transcript for session {session_id}"
        )
        return

@@ -357,9 +296,10 @@ async def upload_transcript(
    try:
        existing = await storage.retrieve(path)
        if len(existing) >= new_size:
            logger.debug(
                f"[Transcript] Skipping upload — existing ({len(existing)}B) "
                f">= new ({new_size}B) for session {session_id}"
            logger.info(
                f"[Transcript] Skipping upload — existing transcript "
                f"({len(existing)}B) >= new ({new_size}B) for session "
                f"{session_id}"
            )
            return
    except (FileNotFoundError, Exception):
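
# Sketch of the monotonic-size guard above: a JSONL transcript is
# append-only, so a longer payload is always newer, and a stale background
# upload can be rejected by a size comparison alone. In-memory store and
# keys are illustrative.
store: dict[str, bytes] = {}

def upload(key: str, payload: bytes) -> bool:
    existing = store.get(key)
    if existing is not None and len(existing) >= len(payload):
        return False  # stale or duplicate upload, keep what we have
    store[key] = payload
    return True

assert upload("s1", b"line1\n")              # first upload stored
assert upload("s1", b"line1\nline2\n")       # longer, so newer
assert not upload("s1", b"line1\n")          # stale shorter upload dropped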

@@ -371,38 +311,16 @@ async def upload_transcript(
        filename=fname,
        content=encoded,
    )

    # Store metadata alongside the transcript so the next turn can detect
    # staleness and only compress the gap instead of the full history.
    # Wrapped in try/except so a metadata write failure doesn't orphan
    # the already-uploaded transcript — the next turn will just fall back
    # to full gap fill (msg_count=0).
    try:
        meta = {"message_count": message_count, "uploaded_at": time.time()}
        mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
        await storage.store(
            workspace_id=mwid,
            file_id=mfid,
            filename=mfname,
            content=json.dumps(meta).encode("utf-8"),
        )
    except Exception as e:
        logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")

    logger.info(
        f"[Transcript] Uploaded {new_size}B "
        f"(stripped from {len(content)}B, msg_count={message_count}) "
        f"for session {session_id}"
        f"[Transcript] Uploaded {new_size} bytes "
        f"(stripped from {len(content)}) for session {session_id}"
    )


async def download_transcript(
    user_id: str, session_id: str
) -> TranscriptDownload | None:
    """Download transcript and metadata from bucket storage.
async def download_transcript(user_id: str, session_id: str) -> str | None:
    """Download transcript from bucket storage.

    Returns a ``TranscriptDownload`` with the JSONL content and the
    ``message_count`` watermark from the upload, or ``None`` if not found.
    Returns the JSONL content string, or ``None`` if not found.
    """
    from backend.util.workspace_storage import get_workspace_storage

@@ -412,6 +330,10 @@ async def download_transcript(
    try:
        data = await storage.retrieve(path)
        content = data.decode("utf-8")
        logger.info(
            f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
        )
        return content
    except FileNotFoundError:
        logger.debug(f"[Transcript] No transcript in storage for {session_id}")
        return None

@@ -419,36 +341,6 @@ async def download_transcript(
        logger.warning(f"[Transcript] Failed to download transcript: {e}")
        return None

    # Try to load metadata (best-effort — old transcripts won't have it)
    message_count = 0
    uploaded_at = 0.0
    try:
        from backend.util.workspace_storage import GCSWorkspaceStorage

        mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
        if isinstance(storage, GCSWorkspaceStorage):
            blob = f"workspaces/{mwid}/{mfid}/{mfname}"
            meta_path = f"gcs://{storage.bucket_name}/{blob}"
        else:
            meta_path = f"local://{mwid}/{mfid}/{mfname}"

        meta_data = await storage.retrieve(meta_path)
        meta = json.loads(meta_data.decode("utf-8"))
        message_count = meta.get("message_count", 0)
        uploaded_at = meta.get("uploaded_at", 0.0)
    except (FileNotFoundError, json.JSONDecodeError, Exception):
        pass  # No metadata — treat as unknown (msg_count=0 → always fill gap)

    logger.debug(
        f"[Transcript] Downloaded {len(content)}B "
        f"(msg_count={message_count}) for session {session_id}"
    )
    return TranscriptDownload(
        content=content,
        message_count=message_count,
        uploaded_at=uploaded_at,
    )


async def delete_transcript(user_id: str, session_id: str) -> None:
    """Delete transcript from bucket storage (e.g. after resume failure)."""

@@ -387,7 +387,7 @@ async def stream_chat_completion(
    if user_id:
        log_meta["user_id"] = user_id

    logger.debug(
    logger.info(
        f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
        f"message_len={len(message) if message else 0}, is_user={is_user_message}",
        extra={

@@ -404,7 +404,7 @@ async def stream_chat_completion(
        fetch_start = time.monotonic()
        session = await get_chat_session(session_id, user_id)
        fetch_time = (time.monotonic() - fetch_start) * 1000
        logger.debug(
        logger.info(
            f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
            f"n_messages={len(session.messages) if session else 0}",
            extra={

@@ -416,7 +416,7 @@ async def stream_chat_completion(
            },
        )
    else:
        logger.debug(
        logger.info(
            f"[TIMING] Using provided session, messages={len(session.messages)}",
            extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
        )

@@ -450,7 +450,7 @@ async def stream_chat_completion(
            message_length=len(message),
        )
        posthog_time = (time.monotonic() - posthog_start) * 1000
        logger.debug(
        logger.info(
            f"[TIMING] track_user_message took {posthog_time:.1f}ms",
            extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
        )

@@ -458,7 +458,7 @@ async def stream_chat_completion(
        upsert_start = time.monotonic()
        session = await upsert_chat_session(session)
        upsert_time = (time.monotonic() - upsert_start) * 1000
        logger.debug(
        logger.info(
            f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
            extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
        )

@@ -503,7 +503,7 @@ async def stream_chat_completion(
    prompt_start = time.monotonic()
    system_prompt, understanding = await _build_system_prompt(user_id)
    prompt_time = (time.monotonic() - prompt_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
    )
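
# Every [TIMING] pair above follows one pattern: take a monotonic start
# stamp, convert the delta to milliseconds, and attach structured fields
# through logging's `extra` mechanism. Minimal reproduction of that shape:
import logging
import time

logger = logging.getLogger("timing-demo")

start = time.monotonic()
# ... the operation being timed ...
duration_ms = (time.monotonic() - start) * 1000
logger.info(
    "[TIMING] operation took %.1fms", duration_ms,
    extra={"json_fields": {"duration_ms": duration_ms}},
)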

@@ -537,7 +537,7 @@ async def stream_chat_completion(

    # Only yield message start for the initial call, not for continuations.
    setup_time = (time.monotonic() - completion_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
    )

@@ -548,7 +548,7 @@ async def stream_chat_completion(
        yield StreamStartStep()

    try:
        logger.debug(
        logger.info(
            "[TIMING] Calling _stream_chat_chunks",
            extra={"json_fields": log_meta},
        )

@@ -988,7 +988,7 @@ async def _stream_chat_chunks(
    if session.user_id:
        log_meta["user_id"] = session.user_id

    logger.debug(
    logger.info(
        f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
        f"user={session.user_id}, n_messages={len(session.messages)}",
        extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},

@@ -1011,7 +1011,7 @@ async def _stream_chat_chunks(
        base_url=config.base_url,
    )
    context_time = (time_module.perf_counter() - context_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] _manage_context_window took {context_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": context_time}},
    )

@@ -1053,7 +1053,7 @@ async def _stream_chat_chunks(
            retry_info = (
                f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
            )
            logger.debug(
            logger.info(
                f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
                extra={
                    "json_fields": {

@@ -1093,7 +1093,7 @@ async def _stream_chat_chunks(
                extra_body=extra_body,
            )
            api_init_time = (time_module.perf_counter() - api_call_start) * 1000
            logger.debug(
            logger.info(
                f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
                extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
            )

@@ -1142,7 +1142,7 @@ async def _stream_chat_chunks(
                        ttfc = (
                            time_module.perf_counter() - api_call_start
                        ) * 1000
                        logger.debug(
                        logger.info(
                            f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
                            f"(since API call), n_chunks={chunk_count}",
                            extra={

@@ -1210,7 +1210,7 @@ async def _stream_chat_chunks(
                        )
                        emitted_start_for_idx.add(idx)
            stream_duration = time_module.perf_counter() - api_call_start
            logger.debug(
            logger.info(
                f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
                f"duration={stream_duration:.2f}s, "
                f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",

@@ -1244,7 +1244,7 @@ async def _stream_chat_chunks(
            raise

    total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
        f"session={session.session_id}, user={session.user_id}",
        extra={"json_fields": {**log_meta, "total_time_ms": total_time}},

@@ -1494,8 +1494,8 @@ async def _yield_tool_call(
        # Mark stream registry task as failed if it was created
        try:
            await stream_registry.mark_task_completed(task_id, status="failed")
        except Exception as mark_err:
            logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
        except Exception:
            pass
        logger.error(
            f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
        )

@@ -143,7 +143,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
            "Transcript was not uploaded to bucket after turn 1 — "
            "Stop hook may not have fired or transcript was too small"
        )
    logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
    logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")

    # Reload session for turn 2
    session = await get_chat_session(session.session_id, test_user_id)

@@ -117,7 +117,7 @@ async def create_task(
    if user_id:
        log_meta["user_id"] = user_id

    logger.debug(
    logger.info(
        f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
        extra={"json_fields": log_meta},
    )

@@ -135,7 +135,7 @@ async def create_task(
    redis_start = time.perf_counter()
    redis = await get_redis_async()
    redis_time = (time.perf_counter() - redis_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] get_redis_async took {redis_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
    )

@@ -158,7 +158,7 @@ async def create_task(
        },
    )
    hset_time = (time.perf_counter() - hset_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] redis.hset took {hset_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
    )

@@ -169,7 +169,7 @@ async def create_task(
        await redis.set(op_key, task_id, ex=config.stream_ttl)

    total_time = (time.perf_counter() - start_time) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
        extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
    )

@@ -230,7 +230,7 @@ async def publish_chunk(
        in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
        or total_time > 50
    ):
        logger.debug(
        logger.info(
            f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
            extra={
                "json_fields": {

@@ -279,7 +279,7 @@ async def subscribe_to_task(
    if user_id:
        log_meta["user_id"] = user_id

    logger.debug(
    logger.info(
        f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
        extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
    )

@@ -289,14 +289,14 @@ async def subscribe_to_task(
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]
    hgetall_time = (time.perf_counter() - redis_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
    )

    if not meta:
        elapsed = (time.perf_counter() - start_time) * 1000
        logger.debug(
        logger.info(
            f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
            extra={
                "json_fields": {

@@ -335,7 +335,7 @@ async def subscribe_to_task(
    xread_start = time.perf_counter()
    messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
    xread_time = (time.perf_counter() - xread_start) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
        extra={
            "json_fields": {

@@ -363,7 +363,7 @@ async def subscribe_to_task(
            except Exception as e:
                logger.warning(f"Failed to replay message: {e}")

    logger.debug(
    logger.info(
        f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
        extra={
            "json_fields": {

@@ -376,7 +376,7 @@ async def subscribe_to_task(

    # Step 2: If task is still running, start stream listener for live updates
    if task_status == "running":
        logger.debug(
        logger.info(
            "[TIMING] Task still running, starting _stream_listener",
            extra={"json_fields": {**log_meta, "task_status": task_status}},
        )

@@ -387,14 +387,14 @@ async def subscribe_to_task(
        _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
    else:
        # Task is completed/failed - add finish marker
        logger.debug(
        logger.info(
            f"[TIMING] Task already {task_status}, adding StreamFinish",
            extra={"json_fields": {**log_meta, "task_status": task_status}},
        )
        await subscriber_queue.put(StreamFinish())

    total_time = (time.perf_counter() - start_time) * 1000
    logger.debug(
    logger.info(
        f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
        f"n_messages_replayed={replayed_count}",
        extra={

@@ -433,7 +433,7 @@ async def _stream_listener(
    if log_meta is None:
        log_meta = {"component": "StreamRegistry", "task_id": task_id}

    logger.debug(
    logger.info(
        f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
        extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
    )

@@ -462,7 +462,7 @@ async def _stream_listener(

            if messages:
                msg_count = sum(len(msgs) for _, msgs in messages)
                logger.debug(
                logger.info(
                    f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
                    extra={
                        "json_fields": {

@@ -475,7 +475,7 @@ async def _stream_listener(
                )
            elif xread_time > 1000:
                # Only log timeouts (30s blocking)
                logger.debug(
                logger.info(
                    f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
                    extra={
                        "json_fields": {

@@ -526,7 +526,7 @@ async def _stream_listener(
                    if first_message_time is None:
                        first_message_time = time.perf_counter()
                        elapsed = (first_message_time - start_time) * 1000
                        logger.debug(
                        logger.info(
                            f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
                            extra={
                                "json_fields": {

@@ -568,7 +568,7 @@ async def _stream_listener(
                    # Stop listening on finish
                    if isinstance(chunk, StreamFinish):
                        total_time = (time.perf_counter() - start_time) * 1000
                        logger.debug(
                        logger.info(
                            f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
                            extra={
                                "json_fields": {

@@ -587,7 +587,7 @@ async def _stream_listener(

    except asyncio.CancelledError:
        elapsed = (time.perf_counter() - start_time) * 1000
        logger.debug(
        logger.info(
            f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
            extra={
                "json_fields": {

@@ -619,7 +619,7 @@ async def _stream_listener(
    finally:
        # Clean up listener task mapping on exit
        total_time = (time.perf_counter() - start_time) * 1000
        logger.debug(
        logger.info(
            f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
            f"delivered={messages_delivered}, xread_count={xread_count}",
            extra={

@@ -829,13 +829,10 @@ async def get_active_task_for_session(
                )
                await mark_task_completed(task_id, "failed")
                continue
        except (ValueError, TypeError) as exc:
            logger.warning(
                f"[TASK_LOOKUP] Failed to parse created_at "
                f"for task {task_id[:8]}...: {exc}"
            )
        except (ValueError, TypeError):
            pass

        logger.debug(
        logger.info(
            f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
        )

||||
@@ -33,6 +33,7 @@ query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) {
|
||||
id
|
||||
identifier
|
||||
title
|
||||
description
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -204,6 +205,7 @@ class SearchFeatureRequestsTool(BaseTool):
|
||||
id=node["id"],
|
||||
identifier=node["identifier"],
|
||||
title=node["title"],
|
||||
description=node.get("description"),
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
@@ -237,11 +239,7 @@ class CreateFeatureRequestTool(BaseTool):
|
||||
"Create a new feature request or add a customer need to an existing one. "
|
||||
"Always search first with search_feature_requests to avoid duplicates. "
|
||||
"If a matching request exists, pass its ID as existing_issue_id to add "
|
||||
"the user's need to it instead of creating a duplicate. "
|
||||
"IMPORTANT: Never include personally identifiable information (PII) in "
|
||||
"the title or description — no names, emails, phone numbers, company "
|
||||
"names, or other identifying details. Write titles and descriptions in "
|
||||
"generic, feature-focused language."
|
||||
"the user's need to it instead of creating a duplicate."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -251,20 +249,11 @@ class CreateFeatureRequestTool(BaseTool):
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Title for the feature request. Must be generic and "
|
||||
"feature-focused — do not include any user names, emails, "
|
||||
"company names, or other PII."
|
||||
),
|
||||
"description": "Title for the feature request.",
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Detailed description of what the user wants and why. "
|
||||
"Must not contain any personally identifiable information "
|
||||
"(PII) — describe the feature need generically without "
|
||||
"referencing specific users, companies, or contact details."
|
||||
),
|
||||
"description": "Detailed description of what the user wants and why.",
|
||||
},
|
||||
"existing_issue_id": {
|
||||
"type": "string",
|
||||
|
||||
@@ -117,11 +117,13 @@ class TestSearchFeatureRequestsTool:
|
||||
"id": "id-1",
|
||||
"identifier": "FR-1",
|
||||
"title": "Dark mode",
|
||||
"description": "Add dark mode support",
|
||||
},
|
||||
{
|
||||
"id": "id-2",
|
||||
"identifier": "FR-2",
|
||||
"title": "Dark theme",
|
||||
"description": None,
|
||||
},
|
||||
]
|
||||
patcher, _ = _mock_linear_config(query_return=_search_response(nodes))
|
||||
|
||||
@@ -486,6 +486,7 @@ class FeatureRequestInfo(BaseModel):
|
||||
id: str
|
||||
identifier: str
|
||||
title: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class FeatureRequestSearchResponse(ToolResponseBase):
|
||||
|
@@ -312,18 +312,8 @@ class ReadWorkspaceFileTool(BaseTool):
        is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
        is_text_file = self._is_text_mime_type(file_info.mime_type)

        # Return inline content for small text/image files (unless force_download_url)
        is_image_file = file_info.mime_type in {
            "image/png",
            "image/jpeg",
            "image/gif",
            "image/webp",
        }
        if (
            is_small_file
            and (is_text_file or is_image_file)
            and not force_download_url
        ):
        # Return inline content for small text files (unless force_download_url)
        if is_small_file and is_text_file and not force_download_url:
            content = await manager.read_file_by_id(target_file_id)
            content_b64 = base64.b64encode(content).decode("utf-8")

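Note: one side of this hunk widens the inline-return gate from "small text files" to "small text or image files". A self-contained sketch of that gate, assuming a hypothetical size limit (the real constant is MAX_INLINE_SIZE_BYTES on the tool class) and the image MIME set shown in the diff:

```python
import base64

MAX_INLINE_SIZE_BYTES = 1 * 1024 * 1024  # assumed limit; the real value lives on the tool
INLINE_IMAGE_TYPES = {"image/png", "image/jpeg", "image/gif", "image/webp"}


def should_inline(
    size_bytes: int, mime_type: str, is_text: bool, force_download_url: bool
) -> bool:
    """Mirror of the widened gate: small text OR image files come back inline."""
    is_small = size_bytes <= MAX_INLINE_SIZE_BYTES
    is_image = mime_type in INLINE_IMAGE_TYPES
    return is_small and (is_text or is_image) and not force_download_url


def encode_inline(content: bytes) -> str:
    # Inline payloads are shipped as base64, as in the hunk above.
    return base64.b64encode(content).decode("utf-8")
```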
@@ -93,15 +93,7 @@ from backend.data.user import (
    get_user_notification_preference,
    update_user_integrations,
)
from backend.data.workspace import (
    count_workspace_files,
    create_workspace_file,
    get_or_create_workspace,
    get_workspace_file,
    get_workspace_file_by_path,
    list_workspace_files,
    soft_delete_workspace_file,
)
from backend.data.workspace import get_or_create_workspace
from backend.util.service import (
    AppService,
    AppServiceClient,
@@ -282,13 +274,7 @@ class DatabaseManager(AppService):
    get_user_execution_summary_data = _(get_user_execution_summary_data)

    # ============ Workspace ============ #
    count_workspace_files = _(count_workspace_files)
    create_workspace_file = _(create_workspace_file)
    get_or_create_workspace = _(get_or_create_workspace)
    get_workspace_file = _(get_workspace_file)
    get_workspace_file_by_path = _(get_workspace_file_by_path)
    list_workspace_files = _(list_workspace_files)
    soft_delete_workspace_file = _(soft_delete_workspace_file)

    # ============ Understanding ============ #
    get_business_understanding = _(get_business_understanding)
@@ -452,13 +438,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    get_user_execution_summary_data = d.get_user_execution_summary_data

    # ============ Workspace ============ #
    count_workspace_files = d.count_workspace_files
    create_workspace_file = d.create_workspace_file
    get_or_create_workspace = d.get_or_create_workspace
    get_workspace_file = d.get_workspace_file
    get_workspace_file_by_path = d.get_workspace_file_by_path
    list_workspace_files = d.list_workspace_files
    soft_delete_workspace_file = d.soft_delete_workspace_file

    # ============ Understanding ============ #
    get_business_understanding = d.get_business_understanding

@@ -164,23 +164,21 @@ async def create_workspace_file(


async def get_workspace_file(
    file_id: str,
    workspace_id: str,
    workspace_id: Optional[str] = None,
) -> Optional[WorkspaceFile]:
    """
    Get a workspace file by ID.

    Args:
        file_id: The file ID
        workspace_id: Workspace ID for scoping (required)
        workspace_id: Optional workspace ID for validation

    Returns:
        WorkspaceFile instance or None
    """
    where_clause: UserWorkspaceFileWhereInput = {
        "id": file_id,
        "isDeleted": False,
        "workspaceId": workspace_id,
    }
    where_clause: dict = {"id": file_id, "isDeleted": False}
    if workspace_id:
        where_clause["workspaceId"] = workspace_id

    file = await UserWorkspaceFile.prisma().find_first(where=where_clause)
    return WorkspaceFile.from_db(file) if file else None
@@ -270,7 +268,7 @@ async def count_workspace_files(
    Returns:
        Number of files
    """
    where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
    where_clause: dict = {"workspaceId": workspace_id}
    if not include_deleted:
        where_clause["isDeleted"] = False

@@ -285,7 +283,7 @@ async def count_workspace_files(


async def soft_delete_workspace_file(
    file_id: str,
    workspace_id: str,
    workspace_id: Optional[str] = None,
) -> Optional[WorkspaceFile]:
    """
    Soft-delete a workspace file.
@@ -295,7 +293,7 @@ async def soft_delete_workspace_file(

    Args:
        file_id: The file ID
        workspace_id: Workspace ID for scoping (required)
        workspace_id: Optional workspace ID for validation

    Returns:
        Updated WorkspaceFile instance or None if not found
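Note: the two docstring variants above correspond to two where-clause shapes: one where workspace_id is always part of the filter (so a lookup can never cross workspaces) and one where it is only added when supplied. A self-contained sketch of both, using plain dicts in place of the generated Prisma UserWorkspaceFileWhereInput type:

```python
from typing import Any, Optional


def scoped_where_clause(file_id: str, workspace_id: str) -> dict[str, Any]:
    """Tightened variant from the hunk: workspace_id is always present."""
    return {
        "id": file_id,
        "isDeleted": False,
        "workspaceId": workspace_id,  # required, not conditionally added
    }


def lenient_where_clause(
    file_id: str, workspace_id: Optional[str] = None
) -> dict[str, Any]:
    """Older optional-scoping variant, also shown in the hunk."""
    where: dict[str, Any] = {"id": file_id, "isDeleted": False}
    if workspace_id:
        where["workspaceId"] = workspace_id
    return where
```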
@@ -28,7 +28,7 @@ from typing import (
import httpx
import uvicorn
from fastapi import FastAPI, Request, responses
from prisma.errors import DataError, UniqueViolationError
from prisma.errors import DataError
from pydantic import BaseModel, TypeAdapter, create_model

import backend.util.exceptions as exceptions
@@ -201,7 +201,6 @@ EXCEPTION_MAPPING = {
        UnhealthyServiceError,
        HTTPClientError,
        HTTPServerError,
        UniqueViolationError,
        *[
            ErrorType
            for _, ErrorType in inspect.getmembers(exceptions)
@@ -417,9 +416,6 @@ class AppService(BaseAppService, ABC):
        self.fastapi_app.add_exception_handler(
            DataError, self._handle_internal_http_error(400)
        )
        self.fastapi_app.add_exception_handler(
            UniqueViolationError, self._handle_internal_http_error(400)
        )
        self.fastapi_app.add_exception_handler(
            Exception, self._handle_internal_http_error(500)
        )
@@ -482,7 +478,6 @@ def get_service_client(
    # Don't retry these specific exceptions that won't be fixed by retrying
    ValueError,  # Invalid input/parameters
    DataError,  # Prisma data integrity errors (foreign key, unique constraints)
    UniqueViolationError,  # Unique constraint violations
    KeyError,  # Missing required data
    TypeError,  # Wrong data types
    AttributeError,  # Missing attributes
@@ -599,15 +594,6 @@ def get_service_client(
            if error_response and error_response.type in EXCEPTION_MAPPING:
                exception_class = EXCEPTION_MAPPING[error_response.type]
                args = error_response.args or [str(e)]

                # Prisma DataError subclasses expect a dict `data` arg,
                # but RPC serialization only preserves the string message
                # from exc.args. Wrap it in the expected structure so
                # the constructor doesn't crash on `.get()`.
                if issubclass(exception_class, DataError):
                    msg = str(args[0]) if args else str(e)
                    raise exception_class({"user_facing_error": {"message": msg}})

                raise exception_class(*args)

            # Otherwise categorize by HTTP status code
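Note: the comments in this hunk describe the client-side rebuild of Prisma errors after an RPC hop: DataError subclasses take a dict `data` argument with a 'user_facing_error' key, but serialization only carries the string message in exc.args. A minimal sketch of that reconstruction, isolated from the surrounding service-client machinery:

```python
from prisma.errors import DataError


def reraise_mapped_exception(
    exception_class: type[Exception], args: list, fallback_msg: str
) -> None:
    """Rebuild and raise the exception named in an RPC error response.

    Follows the wrapping shown in the hunk above: DataError subclasses get
    their string message re-wrapped into the dict shape their constructor
    expects, so `.get()` inside prisma does not crash.
    """
    if issubclass(exception_class, DataError):
        msg = str(args[0]) if args else fallback_msg
        raise exception_class({"user_facing_error": {"message": msg}})
    raise exception_class(*args)
```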
@@ -6,7 +6,6 @@ from unittest.mock import Mock

import httpx
import pytest
from prisma.errors import DataError, UniqueViolationError

from backend.util.service import (
    AppService,
@@ -448,39 +447,6 @@ class TestHTTPErrorRetryBehavior:

        assert "Invalid parameter value" in str(exc_info.value)

    def test_prisma_data_error_reconstructed_correctly(self):
        """Test that DataError subclasses (e.g. UniqueViolationError) are
        reconstructed without crashing.

        Prisma's DataError.__init__ expects a dict `data` arg with
        a 'user_facing_error' key. RPC serialization only preserves the
        string message via exc.args, so the client must wrap it in the
        expected dict structure.
        """
        for exc_type in [DataError, UniqueViolationError]:
            mock_response = Mock()
            mock_response.status_code = 400
            mock_response.json.return_value = {
                "type": exc_type.__name__,
                "args": ["Unique constraint failed on the fields: (`path`)"],
            }
            mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
                "400 Bad Request", request=Mock(), response=mock_response
            )

            client = get_service_client(ServiceTestClient)

            with pytest.raises(exc_type) as exc_info:
                client._handle_call_method_response(  # type: ignore[attr-defined]
                    response=mock_response, method_name="test_method"
                )

            # The exception should have the message preserved
            assert "Unique constraint" in str(exc_info.value)
            # And should have the expected data structure (not crash)
            assert hasattr(exc_info.value, "data")
            assert isinstance(exc_info.value.data, dict)

    def test_client_error_status_codes_coverage(self):
        """Test that various 4xx status codes are all wrapped as HTTPClientError."""
        client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]

@@ -12,8 +12,15 @@ from typing import Optional

from prisma.errors import UniqueViolationError

from backend.data.db_accessors import workspace_db
from backend.data.workspace import WorkspaceFile
from backend.data.workspace import (
    WorkspaceFile,
    count_workspace_files,
    create_workspace_file,
    get_workspace_file,
    get_workspace_file_by_path,
    list_workspace_files,
    soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
@@ -118,9 +125,8 @@ class WorkspaceManager:
        Raises:
            FileNotFoundError: If file doesn't exist
        """
        db = workspace_db()
        resolved_path = self._resolve_path(path)
        file = await db.get_workspace_file_by_path(self.workspace_id, resolved_path)
        file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
        if file is None:
            raise FileNotFoundError(f"File not found at path: {resolved_path}")

@@ -140,8 +146,7 @@ class WorkspaceManager:
        Raises:
            FileNotFoundError: If file doesn't exist
        """
        db = workspace_db()
        file = await db.get_workspace_file(file_id, self.workspace_id)
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            raise FileNotFoundError(f"File not found: {file_id}")

@@ -199,10 +204,8 @@ class WorkspaceManager:
        # For overwrite=True, we let the write proceed and handle via UniqueViolationError
        # This ensures the new file is written to storage BEFORE the old one is deleted,
        # preventing data loss if the new write fails
        db = workspace_db()

        if not overwrite:
            existing = await db.get_workspace_file_by_path(self.workspace_id, path)
            existing = await get_workspace_file_by_path(self.workspace_id, path)
            if existing is not None:
                raise ValueError(f"File already exists at path: {path}")

@@ -229,7 +232,7 @@ class WorkspaceManager:
        # Create database record - handle race condition where another request
        # created a file at the same path between our check and create
        try:
            file = await db.create_workspace_file(
            file = await create_workspace_file(
                workspace_id=self.workspace_id,
                file_id=file_id,
                name=filename,
@@ -243,12 +246,12 @@ class WorkspaceManager:
            # Race condition: another request created a file at this path
            if overwrite:
                # Re-fetch and delete the conflicting file, then retry
                existing = await db.get_workspace_file_by_path(self.workspace_id, path)
                existing = await get_workspace_file_by_path(self.workspace_id, path)
                if existing:
                    await self.delete_file(existing.id)
                # Retry the create - if this also fails, clean up storage file
                try:
                    file = await db.create_workspace_file(
                    file = await create_workspace_file(
                        workspace_id=self.workspace_id,
                        file_id=file_id,
                        name=filename,
@@ -311,9 +314,8 @@ class WorkspaceManager:
            List of WorkspaceFile instances
        """
        effective_path = self._get_effective_path(path, include_all_sessions)
        db = workspace_db()

        return await db.list_workspace_files(
        return await list_workspace_files(
            workspace_id=self.workspace_id,
            path_prefix=effective_path,
            limit=limit,
@@ -330,8 +332,7 @@ class WorkspaceManager:
        Returns:
            True if deleted, False if not found
        """
        db = workspace_db()
        file = await db.get_workspace_file(file_id, self.workspace_id)
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            return False

@@ -344,7 +345,7 @@ class WorkspaceManager:
        # Continue with database soft-delete even if storage delete fails

        # Soft-delete database record
        result = await db.soft_delete_workspace_file(file_id, self.workspace_id)
        result = await soft_delete_workspace_file(file_id, self.workspace_id)
        return result is not None

    async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
@@ -361,8 +362,7 @@ class WorkspaceManager:
        Raises:
            FileNotFoundError: If file doesn't exist
        """
        db = workspace_db()
        file = await db.get_workspace_file(file_id, self.workspace_id)
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            raise FileNotFoundError(f"File not found: {file_id}")

@@ -379,8 +379,7 @@ class WorkspaceManager:
        Returns:
            WorkspaceFile instance or None
        """
        db = workspace_db()
        return await db.get_workspace_file(file_id, self.workspace_id)
        return await get_workspace_file(file_id, self.workspace_id)

    async def get_file_info_by_path(self, path: str) -> Optional[WorkspaceFile]:
        """
@@ -395,9 +394,8 @@ class WorkspaceManager:
        Returns:
            WorkspaceFile instance or None
        """
        db = workspace_db()
        resolved_path = self._resolve_path(path)
        return await db.get_workspace_file_by_path(self.workspace_id, resolved_path)
        return await get_workspace_file_by_path(self.workspace_id, resolved_path)

    async def get_file_count(
        self,
@@ -419,8 +417,7 @@ class WorkspaceManager:
            Number of files
        """
        effective_path = self._get_effective_path(path, include_all_sessions)
        db = workspace_db()

        return await db.count_workspace_files(
        return await count_workspace_files(
            self.workspace_id, path_prefix=effective_path
        )
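Note: the comments in the write path above describe an overwrite flow where the new blob reaches storage first and a path collision only surfaces as a UniqueViolationError on the database insert; the conflicting record is then deleted and the insert retried once. A generic sketch of that shape with the DB operations injected as callables (the real code calls create_workspace_file, get_workspace_file_by_path, and delete_file directly):

```python
from typing import Awaitable, Callable, Optional, Protocol

from prisma.errors import UniqueViolationError


class HasId(Protocol):
    id: str


async def create_with_overwrite(
    create: Callable[[], Awaitable[HasId]],
    find_existing: Callable[[], Awaitable[Optional[HasId]]],
    delete_file: Callable[[str], Awaitable[None]],
) -> HasId:
    """Write-then-resolve-conflict flow from the hunks above."""
    try:
        return await create()
    except UniqueViolationError:
        # Race: another request created a file at this path first.
        existing = await find_existing()
        if existing is not None:
            await delete_file(existing.id)
        # A second UniqueViolationError here would mean cleaning up the
        # already-written storage blob, as the diff's comments note.
        return await create()
```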
@@ -93,14 +93,7 @@ class WorkspaceStorageBackend(ABC):


class GCSWorkspaceStorage(WorkspaceStorageBackend):
    """Google Cloud Storage implementation for workspace storage.

    Each instance owns a single ``aiohttp.ClientSession`` and GCS async
    client. Because ``ClientSession`` is bound to the event loop on which it
    was created, callers that run on separate loops (e.g. copilot executor
    worker threads) **must** obtain their own ``GCSWorkspaceStorage`` instance
    via :func:`get_workspace_storage` which is event-loop-aware.
    """
    """Google Cloud Storage implementation for workspace storage."""

    def __init__(self, bucket_name: str):
        self.bucket_name = bucket_name
@@ -344,73 +337,60 @@ class LocalWorkspaceStorage(WorkspaceStorageBackend):
        raise ValueError(f"Invalid storage path format: {storage_path}")


# ---------------------------------------------------------------------------
# Storage instance management
# ---------------------------------------------------------------------------
# ``aiohttp.ClientSession`` is bound to the event loop where it is created.
# The copilot executor runs each worker in its own thread with a dedicated
# event loop, so a single global ``GCSWorkspaceStorage`` instance would break.
#
# For **local storage** a single shared instance is fine (no async I/O).
# For **GCS storage** we keep one instance *per event loop* so every loop
# gets its own ``ClientSession``.
# ---------------------------------------------------------------------------

_local_storage: Optional[LocalWorkspaceStorage] = None
_gcs_storages: dict[int, GCSWorkspaceStorage] = {}
# Global storage backend instance
_workspace_storage: Optional[WorkspaceStorageBackend] = None
_storage_lock = asyncio.Lock()


async def get_workspace_storage() -> WorkspaceStorageBackend:
    """Return a workspace storage backend for the **current** event loop.

    * Local storage → single shared instance (no event-loop affinity).
    * GCS storage → one instance per event loop to avoid cross-loop
      ``aiohttp`` errors.
    """
    global _local_storage
    Get the workspace storage backend instance.

    config = Config()
    Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
    """
    global _workspace_storage

    # --- Local storage (shared) ---
    if not config.media_gcs_bucket_name:
        if _local_storage is None:
            storage_dir = (
                config.workspace_storage_dir if config.workspace_storage_dir else None
            )
            logger.info(f"Using local workspace storage: {storage_dir or 'default'}")
            _local_storage = LocalWorkspaceStorage(storage_dir)
        return _local_storage
    if _workspace_storage is None:
        async with _storage_lock:
            if _workspace_storage is None:
                config = Config()

    # --- GCS storage (per event loop) ---
    loop_id = id(asyncio.get_running_loop())
    if loop_id not in _gcs_storages:
        logger.info(
            f"Creating GCS workspace storage for loop {loop_id}: "
            f"{config.media_gcs_bucket_name}"
        )
        _gcs_storages[loop_id] = GCSWorkspaceStorage(config.media_gcs_bucket_name)
    return _gcs_storages[loop_id]
                if config.media_gcs_bucket_name:
                    logger.info(
                        f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
                    )
                    _workspace_storage = GCSWorkspaceStorage(
                        config.media_gcs_bucket_name
                    )
                else:
                    storage_dir = (
                        config.workspace_storage_dir
                        if config.workspace_storage_dir
                        else None
                    )
                    logger.info(
                        f"Using local workspace storage: {storage_dir or 'default'}"
                    )
                    _workspace_storage = LocalWorkspaceStorage(storage_dir)

    return _workspace_storage


async def shutdown_workspace_storage() -> None:
    """Shut down workspace storage for the **current** event loop.

    Closes the ``aiohttp`` session owned by the current loop's GCS instance.
    Each worker thread should call this on its own loop before the loop is
    destroyed. The REST API lifespan hook calls it for the main server loop.
    """
    global _local_storage
    Properly shutdown the global workspace storage backend.

    loop_id = id(asyncio.get_running_loop())
    storage = _gcs_storages.pop(loop_id, None)
    if storage is not None:
        await storage.close()
    Closes aiohttp sessions and other resources for GCS backend.
    Should be called during application shutdown.
    """
    global _workspace_storage

    # Clear local storage only when the last GCS instance is gone
    # (i.e. full shutdown, not just a single worker stopping).
    if not _gcs_storages:
        _local_storage = None
    if _workspace_storage is not None:
        async with _storage_lock:
            if _workspace_storage is not None:
                if isinstance(_workspace_storage, GCSWorkspaceStorage):
                    await _workspace_storage.close()
                _workspace_storage = None


def compute_file_checksum(content: bytes) -> str:
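Note: one side of this hunk keeps storage instances per event loop because aiohttp sessions are loop-bound. A self-contained sketch of that registry shape, with a stand-in class in place of GCSWorkspaceStorage (the keying by id(loop) and the pop-and-close shutdown mirror get_workspace_storage and shutdown_workspace_storage above):

```python
import asyncio


class LoopBoundStorage:
    """Stand-in for GCSWorkspaceStorage: owns one loop-bound session."""

    def __init__(self, bucket_name: str) -> None:
        self.bucket_name = bucket_name

    async def close(self) -> None:
        pass  # the real class would close its aiohttp.ClientSession here


_instances: dict[int, LoopBoundStorage] = {}


async def get_storage(bucket_name: str) -> LoopBoundStorage:
    # One instance per running event loop, keyed by id(loop).
    loop_id = id(asyncio.get_running_loop())
    if loop_id not in _instances:
        _instances[loop_id] = LoopBoundStorage(bucket_name)
    return _instances[loop_id]


async def shutdown_storage() -> None:
    # Pop-and-close on the owning loop, so the session is closed on the
    # same loop that created it.
    storage = _instances.pop(id(asyncio.get_running_loop()), None)
    if storage is not None:
        await storage.close()
```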
autogpt_platform/backend/poetry.lock (generated, 99 changed lines)
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.

[[package]]
name = "aio-pika"
@@ -374,7 +374,7 @@ description = "LTS Port of Python audioop"
optional = false
python-versions = ">=3.13"
groups = ["main"]
markers = "python_version >= \"3.13\""
markers = "python_version == \"3.13\""
files = [
    {file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd3d4602dc64914d462924a08c1a9816435a2155d74f325853c1f1ac3b2d9800"},
    {file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:550c114a8df0aafe9a05442a1162dfc8fec37e9af1d625ae6060fed6e756f303"},
@@ -474,7 +474,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
optional = false
python-versions = "<3.11,>=3.8"
groups = ["main"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
    {file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
    {file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
@@ -487,7 +487,7 @@ description = "Backport of CPython tarfile module"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "python_version <= \"3.11\""
markers = "python_version < \"3.12\""
files = [
    {file = "backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34"},
    {file = "backports_tarfile-1.2.0.tar.gz", hash = "sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991"},
@@ -563,6 +563,18 @@ webencodings = "*"
[package.extras]
css = ["tinycss2 (>=1.1.0,<1.5)"]

[[package]]
name = "bracex"
version = "2.6"
description = "Bash style brace expander."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "bracex-2.6-py3-none-any.whl", hash = "sha256:0b0049264e7340b3ec782b5cb99beb325f36c3782a32e36e876452fd49a09952"},
    {file = "bracex-2.6.tar.gz", hash = "sha256:98f1347cd77e22ee8d967a30ad4e310b233f7754dbf31ff3fceb76145ba47dc7"},
]

[[package]]
name = "browserbase"
version = "1.4.0"
@@ -659,7 +671,6 @@ description = "Foreign Function Interface for Python calling C code."
optional = false
python-versions = ">=3.9"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\" or sys_platform == \"darwin\""
files = [
    {file = "cffi-2.0.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44"},
    {file = "cffi-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49"},
@@ -1172,6 +1183,18 @@ idna = ["idna (>=3.10)"]
trio = ["trio (>=0.30)"]
wmi = ["wmi (>=1.5.1) ; platform_system == \"Windows\""]

[[package]]
name = "dockerfile-parse"
version = "2.0.1"
description = "Python library for Dockerfile manipulation"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
    {file = "dockerfile-parse-2.0.1.tar.gz", hash = "sha256:3184ccdc513221983e503ac00e1aa504a2aa8f84e5de673c46b0b6eee99ec7bc"},
    {file = "dockerfile_parse-2.0.1-py2.py3-none-any.whl", hash = "sha256:bdffd126d2eb26acf1066acb54cb2e336682e1d72b974a40894fac76a4df17f6"},
]

[[package]]
name = "docstring-parser"
version = "0.17.0"
@@ -1258,40 +1281,43 @@ pgp = ["gpg"]

[[package]]
name = "e2b"
version = "1.11.1"
version = "2.13.2"
description = "E2B SDK that give agents cloud environments"
optional = false
python-versions = "<4.0,>=3.9"
python-versions = "<4.0,>=3.10"
groups = ["main"]
files = [
    {file = "e2b-1.11.1-py3-none-any.whl", hash = "sha256:1ecb123873788472731c101939a494ab852cbcce0f913df6f7ecb194ae932130"},
    {file = "e2b-1.11.1.tar.gz", hash = "sha256:7f7b6f238208d0a23353bb0da01f91a924321b57c61b176506862cbc1493ce8c"},
    {file = "e2b-2.13.2-py3-none-any.whl", hash = "sha256:d91d5293bc0dd1917c72a6e6b35e86513607be2666a14ae18c57b921e7864de4"},
    {file = "e2b-2.13.2.tar.gz", hash = "sha256:c0e81a3920091874fdf73c0b8f376b28766212db9f1cea5d8bd56a2e95d2436c"},
]

[package.dependencies]
attrs = ">=23.2.0"
dockerfile-parse = ">=2.0.1,<3.0.0"
httpcore = ">=1.0.5,<2.0.0"
httpx = ">=0.27.0,<1.0.0"
packaging = ">=24.1"
protobuf = ">=4.21.0"
python-dateutil = ">=2.8.2"
rich = ">=14.0.0"
typing-extensions = ">=4.1.0"
wcmatch = ">=10.1,<11.0"

[[package]]
name = "e2b-code-interpreter"
version = "1.5.2"
version = "2.4.1"
description = "E2B Code Interpreter - Stateful code execution"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
    {file = "e2b_code_interpreter-1.5.2-py3-none-any.whl", hash = "sha256:5c3188d8f25226b28fef4b255447cc6a4c36afb748bdd5180b45be486d5169f3"},
    {file = "e2b_code_interpreter-1.5.2.tar.gz", hash = "sha256:3bd6ea70596290e85aaf0a2f19f28bf37a5e73d13086f5e6a0080bb591c5a547"},
    {file = "e2b_code_interpreter-2.4.1-py3-none-any.whl", hash = "sha256:15d35f025b4a15033e119f2e12e7ac65657ad2b5a013fa9149e74581fbee778a"},
    {file = "e2b_code_interpreter-2.4.1.tar.gz", hash = "sha256:4b15014ee0d0dfcdc3072e1f409cbb87ca48f48d53d75629b7257e5513b9e7dd"},
]

[package.dependencies]
attrs = ">=21.3.0"
e2b = ">=1.5.4,<2.0.0"
e2b = ">=2.7.0,<3.0.0"
httpx = ">=0.20.0,<1.0.0"

[[package]]
@@ -1361,7 +1387,7 @@ description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
    {file = "exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598"},
    {file = "exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219"},
@@ -1843,16 +1869,16 @@ files = [
google-auth = ">=2.14.1,<3.0.0"
googleapis-common-protos = ">=1.56.2,<2.0.0"
grpcio = [
    {version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
    {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
    {version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
]
grpcio-status = [
    {version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
    {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
    {version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
]
proto-plus = [
    {version = ">=1.22.3,<2.0.0"},
    {version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
    {version = ">=1.22.3,<2.0.0"},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
requests = ">=2.18.0,<3.0.0"
@@ -1963,8 +1989,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0"
grpcio = ">=1.33.2,<2.0.0"
proto-plus = [
    {version = ">=1.22.3,<2.0.0"},
    {version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
    {version = ">=1.22.3,<2.0.0"},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"

@@ -2024,9 +2050,9 @@ google-cloud-core = ">=2.0.0,<3.0.0"
grpc-google-iam-v1 = ">=0.12.4,<1.0.0"
opentelemetry-api = ">=1.9.0"
proto-plus = [
    {version = ">=1.22.0,<2.0.0"},
    {version = ">=1.22.2,<2.0.0", markers = "python_version >= \"3.11\""},
    {version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
    {version = ">=1.22.2,<2.0.0", markers = "python_version >= \"3.11\" and python_version < \"3.13\""},
    {version = ">=1.22.0,<2.0.0", markers = "python_version < \"3.11\""},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"

@@ -3870,7 +3896,7 @@ description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.10"
groups = ["main"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
    {file = "numpy-2.2.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb"},
    {file = "numpy-2.2.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90"},
@@ -4355,9 +4381,9 @@ files = [

[package.dependencies]
numpy = [
    {version = ">=1.22.4", markers = "python_version < \"3.11\""},
    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
    {version = ">=1.22.4", markers = "python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -4600,8 +4626,8 @@ pinecone-plugin-interface = ">=0.0.7,<0.0.8"
python-dateutil = ">=2.5.3"
typing-extensions = ">=3.7.4"
urllib3 = [
    {version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""},
    {version = ">=1.26.5", markers = "python_version >= \"3.12\" and python_version < \"4.0\""},
    {version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""},
]

[package.extras]
@@ -5429,7 +5455,7 @@ description = "C parser in Python"
optional = false
python-versions = ">=3.10"
groups = ["main"]
markers = "(platform_python_implementation != \"PyPy\" or sys_platform == \"darwin\") and implementation_name != \"PyPy\""
markers = "implementation_name != \"PyPy\""
files = [
    {file = "pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992"},
    {file = "pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29"},
@@ -6198,10 +6224,10 @@ files = [
grpcio = ">=1.41.0"
httpx = {version = ">=0.20.0", extras = ["http2"]}
numpy = [
    {version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""},
    {version = ">=1.21", markers = "python_version == \"3.11\""},
    {version = ">=2.1.0", markers = "python_version == \"3.13\""},
    {version = ">=1.21", markers = "python_version == \"3.11\""},
    {version = ">=1.26", markers = "python_version == \"3.12\""},
    {version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""},
]
portalocker = ">=2.7.0,<4.0"
protobuf = ">=3.20.0"
@@ -7407,7 +7433,7 @@ description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["main", "dev"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
    {file = "tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867"},
    {file = "tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9"},
@@ -7931,6 +7957,21 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"

[[package]]
name = "wcmatch"
version = "10.1"
description = "Wildcard/glob file name matcher."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a"},
    {file = "wcmatch-10.1.tar.gz", hash = "sha256:f11f94208c8c8484a16f4f48638a85d771d9513f4ab3f37595978801cb9465af"},
]

[package.dependencies]
bracex = ">=2.1.1"

[[package]]
name = "webencodings"
version = "0.5.1"
@@ -8530,4 +8571,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
content-hash = "38ebd9d612b127781d346935b696fb362d04826342adbc1e27fe7e1294478234"

@@ -20,7 +20,7 @@ claude-agent-sdk = "^0.1.0"
click = "^8.2.0"
cryptography = "^46.0"
discord-py = "^2.5.2"
e2b-code-interpreter = "^1.5.2"
e2b-code-interpreter = "^2.4.1"
elevenlabs = "^1.50.0"
fastapi = "^0.128.6"
feedparser = "^6.0.11"
@@ -1,39 +1,21 @@
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { Button } from "@/components/atoms/Button/Button";
import { Switch } from "@/components/atoms/Switch/Switch";
import { Text } from "@/components/atoms/Text/Text";
import { CaretDownIcon } from "@phosphor-icons/react";

type Props = {
  nodeId: string;
};

export function NodeAdvancedToggle({ nodeId }: Props) {
export const NodeAdvancedToggle = ({ nodeId }: { nodeId: string }) => {
  const showAdvanced = useNodeStore(
    (state) => state.nodeAdvancedStates[nodeId] || false,
  );
  const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
  return (
    <div className="flex items-center justify-start gap-2 bg-white px-5 pb-3.5">
      <Button
        variant="ghost"
        className="h-fit min-w-0 p-0 hover:border-transparent hover:bg-transparent"
        onClick={() => setShowAdvanced(nodeId, !showAdvanced)}
        aria-expanded={showAdvanced}
      >
        <Text
          variant="body"
          as="span"
          className="flex items-center gap-2 !font-semibold text-slate-700"
        >
          Advanced{" "}
          <CaretDownIcon
            size={16}
            weight="bold"
            className={`transition-transform ${showAdvanced ? "rotate-180" : ""}`}
            aria-hidden
          />
        </Text>
      </Button>
    <div className="flex items-center justify-between gap-2 rounded-b-xlarge border-t border-zinc-200 bg-white px-5 py-3.5">
      <Text variant="body" className="font-medium text-slate-700">
        Advanced
      </Text>
      <Switch
        onCheckedChange={(checked) => setShowAdvanced(nodeId, checked)}
        checked={showAdvanced}
      />
    </div>
  );
}
};

@@ -23,7 +23,6 @@ export function CopilotPage() {
    status,
    error,
    stop,
    isReconnecting,
    createSession,
    onSend,
    isLoadingSession,
@@ -72,7 +71,6 @@ export function CopilotPage() {
      sessionId={sessionId}
      isLoadingSession={isLoadingSession}
      isCreatingSession={isCreatingSession}
      isReconnecting={isReconnecting}
      onCreateSession={createSession}
      onSend={onSend}
      onStop={stop}

@@ -14,8 +14,6 @@ export interface ChatContainerProps {
  sessionId: string | null;
  isLoadingSession: boolean;
  isCreatingSession: boolean;
  /** True when backend has an active stream but we haven't reconnected yet. */
  isReconnecting?: boolean;
  onCreateSession: () => void | Promise<string>;
  onSend: (message: string) => void | Promise<void>;
  onStop: () => void;
@@ -28,13 +26,11 @@ export const ChatContainer = ({
  sessionId,
  isLoadingSession,
  isCreatingSession,
  isReconnecting,
  onCreateSession,
  onSend,
  onStop,
  headerSlot,
}: ChatContainerProps) => {
  const isBusy = status === "streaming" || !!isReconnecting;
  const inputLayoutId = "copilot-2-chat-input";

  return (
@@ -60,8 +56,8 @@ export const ChatContainer = ({
        <ChatInput
          inputId="chat-input-session"
          onSend={onSend}
          disabled={isBusy}
          isStreaming={isBusy}
          disabled={status === "streaming"}
          isStreaming={status === "streaming"}
          onStop={onStop}
          placeholder="What else can I help with?"
        />
||||
@@ -1,713 +1,63 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { ToolUIPart } from "ai";
|
||||
import {
|
||||
CheckCircleIcon,
|
||||
CircleDashedIcon,
|
||||
CircleIcon,
|
||||
FileIcon,
|
||||
FilesIcon,
|
||||
GearIcon,
|
||||
GlobeIcon,
|
||||
ListChecksIcon,
|
||||
MagnifyingGlassIcon,
|
||||
PencilSimpleIcon,
|
||||
TerminalIcon,
|
||||
TrashIcon,
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { GearIcon } from "@phosphor-icons/react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
import {
|
||||
ContentCodeBlock,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool name helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function extractToolName(part: ToolUIPart): string {
|
||||
// ToolUIPart.type is "tool-{name}", extract the name portion.
|
||||
return part.type.replace(/^tool-/, "");
|
||||
}
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
// "search_docs" → "Search docs", "Read" → "Read"
|
||||
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool categorization */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
type ToolCategory =
|
||||
| "bash"
|
||||
| "web"
|
||||
| "file-read"
|
||||
| "file-write"
|
||||
| "file-delete"
|
||||
| "file-list"
|
||||
| "search"
|
||||
| "edit"
|
||||
| "todo"
|
||||
| "other";
|
||||
|
||||
function getToolCategory(toolName: string): ToolCategory {
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return "bash";
|
||||
case "web_fetch":
|
||||
case "WebSearch":
|
||||
case "WebFetch":
|
||||
return "web";
|
||||
case "read_workspace_file":
|
||||
case "Read":
|
||||
return "file-read";
|
||||
case "write_workspace_file":
|
||||
case "Write":
|
||||
return "file-write";
|
||||
case "delete_workspace_file":
|
||||
return "file-delete";
|
||||
case "list_workspace_files":
|
||||
case "Glob":
|
||||
return "file-list";
|
||||
case "Grep":
|
||||
return "search";
|
||||
case "Edit":
|
||||
return "edit";
|
||||
case "TodoWrite":
|
||||
return "todo";
|
||||
default:
|
||||
return "other";
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool icon */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function ToolIcon({
|
||||
category,
|
||||
isStreaming,
|
||||
isError,
|
||||
}: {
|
||||
category: ToolCategory;
|
||||
isStreaming: boolean;
|
||||
isError: boolean;
|
||||
}) {
|
||||
if (isError) {
|
||||
return (
|
||||
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={14} />;
|
||||
}
|
||||
|
||||
const iconClass = "text-neutral-400";
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "web":
|
||||
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-read":
|
||||
return <FileIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-write":
|
||||
return <FileIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-delete":
|
||||
return <TrashIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-list":
|
||||
return <FilesIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "search":
|
||||
return (
|
||||
<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "edit":
|
||||
return (
|
||||
<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "todo":
|
||||
return (
|
||||
<ListChecksIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
default:
|
||||
return <GearIcon size={14} weight="regular" className={iconClass} />;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Accordion icon (larger, for the accordion header) */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function AccordionIcon({ category }: { category: ToolCategory }) {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return <TerminalIcon size={32} weight="light" />;
|
||||
case "web":
|
||||
return <GlobeIcon size={32} weight="light" />;
|
||||
case "file-read":
|
||||
case "file-write":
|
||||
return <FileIcon size={32} weight="light" />;
|
||||
case "file-delete":
|
||||
return <TrashIcon size={32} weight="light" />;
|
||||
case "file-list":
|
||||
return <FilesIcon size={32} weight="light" />;
|
||||
case "search":
|
||||
return <MagnifyingGlassIcon size={32} weight="light" />;
|
||||
case "edit":
|
||||
return <PencilSimpleIcon size={32} weight="light" />;
|
||||
case "todo":
|
||||
return <ListChecksIcon size={32} weight="light" />;
|
||||
default:
|
||||
return <GearIcon size={32} weight="light" />;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Input extraction */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getInputSummary(toolName: string, input: unknown): string | null {
|
||||
if (!input || typeof input !== "object") return null;
|
||||
const inp = input as Record<string, unknown>;
|
||||
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return typeof inp.command === "string" ? inp.command : null;
|
||||
case "web_fetch":
|
||||
case "WebFetch":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "WebSearch":
|
||||
return typeof inp.query === "string" ? inp.query : null;
|
||||
case "read_workspace_file":
|
||||
case "Read":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "write_workspace_file":
|
||||
case "Write":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "delete_workspace_file":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "Glob":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "Grep":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "Edit":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "TodoWrite": {
|
||||
// Extract the in-progress task name for the status line
|
||||
const todos = Array.isArray(inp.todos) ? inp.todos : [];
|
||||
const active = todos.find(
|
||||
(t: Record<string, unknown>) => t.status === "in_progress",
|
||||
);
|
||||
if (active && typeof active.activeForm === "string")
|
||||
return active.activeForm;
|
||||
if (active && typeof active.content === "string") return active.content;
|
||||
return null;
|
||||
}
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function truncate(text: string, maxLen: number): string {
|
||||
if (text.length <= maxLen) return text;
|
||||
return text.slice(0, maxLen).trimEnd() + "…";
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Animation text */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
|
||||
const toolName = extractToolName(part);
|
||||
const summary = getInputSummary(toolName, part.input);
|
||||
const shortSummary = summary ? truncate(summary, 60) : null;
|
||||
function getAnimationText(part: ToolUIPart): string {
|
||||
const label = formatToolName(extractToolName(part));
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return shortSummary ? `Running: ${shortSummary}` : "Running command…";
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searching "${shortSummary}"`
|
||||
: "Searching the web…";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetching ${shortSummary}`
|
||||
: "Fetching web content…";
|
||||
case "file-read":
|
||||
return shortSummary ? `Reading ${shortSummary}` : "Reading file…";
|
||||
case "file-write":
|
||||
return shortSummary ? `Writing ${shortSummary}` : "Writing file…";
|
||||
case "file-delete":
|
||||
return shortSummary ? `Deleting ${shortSummary}` : "Deleting file…";
|
||||
case "file-list":
|
||||
return shortSummary ? `Listing ${shortSummary}` : "Listing files…";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searching for "${shortSummary}"`
|
||||
: "Searching…";
|
||||
case "edit":
|
||||
return shortSummary ? `Editing ${shortSummary}` : "Editing file…";
|
||||
case "todo":
|
||||
return shortSummary ? `${shortSummary}` : "Updating task list…";
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}…`;
|
||||
}
|
||||
}
|
||||
case "output-available": {
|
||||
switch (category) {
|
||||
case "bash": {
|
||||
const exitCode = getExitCode(part.output);
|
||||
if (exitCode !== null && exitCode !== 0) {
|
||||
return `Command exited with code ${exitCode}`;
|
||||
}
|
||||
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
|
||||
}
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searched "${shortSummary}"`
|
||||
: "Web search completed";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetched ${shortSummary}`
|
||||
: "Fetched web content";
|
||||
case "file-read":
|
||||
return shortSummary ? `Read ${shortSummary}` : "File read completed";
|
||||
case "file-write":
|
||||
return shortSummary ? `Wrote ${shortSummary}` : "File written";
|
||||
case "file-delete":
|
||||
return shortSummary ? `Deleted ${shortSummary}` : "File deleted";
|
||||
case "file-list":
|
||||
return "Listed files";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searched for "${shortSummary}"`
|
||||
: "Search completed";
|
||||
case "edit":
|
||||
return shortSummary ? `Edited ${shortSummary}` : "Edit completed";
|
||||
case "todo":
|
||||
return "Updated task list";
|
||||
default:
|
||||
return `${formatToolName(toolName)} completed`;
|
||||
}
|
||||
}
|
||||
case "output-error": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return "Command failed";
|
||||
case "web":
|
||||
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
|
||||
default:
|
||||
return `${formatToolName(toolName)} failed`;
|
||||
}
|
||||
}
|
||||
case "input-available":
|
||||
return `Running ${label}…`;
|
||||
case "output-available":
|
||||
return `${label} completed`;
|
||||
case "output-error":
|
||||
return `${label} failed`;
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}…`;
|
||||
return `Running ${label}…`;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Output parsing helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function parseOutput(output: unknown): Record<string, unknown> | null {
|
||||
if (!output) return null;
|
||||
if (typeof output === "object") return output as Record<string, unknown>;
|
||||
if (typeof output === "string") {
|
||||
const trimmed = output.trim();
|
||||
if (!trimmed) return null;
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
if (
|
||||
typeof parsed === "object" &&
|
||||
parsed !== null &&
|
||||
!Array.isArray(parsed)
|
||||
)
|
||||
return parsed;
|
||||
} catch {
|
||||
// Return as a message wrapper for plain text output
|
||||
return { _raw: trimmed };
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract text from MCP-style content blocks.
|
||||
* SDK built-in tools (WebSearch, etc.) may return `{content: [{type:"text", text:"..."}]}`.
|
||||
*/
|
||||
function extractMcpText(output: Record<string, unknown>): string | null {
|
||||
if (Array.isArray(output.content)) {
|
||||
const texts = (output.content as Array<Record<string, unknown>>)
|
||||
.filter((b) => b.type === "text" && typeof b.text === "string")
|
||||
.map((b) => b.text as string);
|
||||
if (texts.length > 0) return texts.join("\n");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function getExitCode(output: unknown): number | null {
|
||||
const parsed = parseOutput(output);
|
||||
if (!parsed) return null;
|
||||
if (typeof parsed.exit_code === "number") return parsed.exit_code;
|
||||
return null;
|
||||
}
|
||||
|
||||
function getStringField(
|
||||
obj: Record<string, unknown>,
|
||||
...keys: string[]
|
||||
): string | null {
|
||||
for (const key of keys) {
|
||||
if (typeof obj[key] === "string" && obj[key].length > 0)
|
||||
return obj[key] as string;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Accordion content per tool category */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
interface AccordionData {
|
||||
title: string;
|
||||
description?: string;
|
||||
content: React.ReactNode;
|
||||
}
|
||||
|
||||
function getBashAccordionData(
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const inp = (input && typeof input === "object" ? input : {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const command = typeof inp.command === "string" ? inp.command : "Command";
|
||||
|
||||
const stdout = getStringField(output, "stdout");
|
||||
const stderr = getStringField(output, "stderr");
|
||||
const exitCode =
|
||||
typeof output.exit_code === "number" ? output.exit_code : null;
|
||||
const timedOut = output.timed_out === true;
|
||||
const message = getStringField(output, "message");
|
||||
|
||||
const title = timedOut
|
||||
? "Command timed out"
|
||||
: exitCode !== null && exitCode !== 0
|
||||
? `Command failed (exit ${exitCode})`
|
||||
: "Command output";
|
||||
|
||||
return {
|
||||
title,
|
||||
description: truncate(command, 80),
|
||||
content: (
|
||||
<div className="space-y-2">
|
||||
{stdout && (
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-slate-500">stdout</p>
|
||||
<ContentCodeBlock>{truncate(stdout, 2000)}</ContentCodeBlock>
|
||||
</div>
|
||||
)}
|
||||
{stderr && (
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-slate-500">stderr</p>
|
||||
<ContentCodeBlock>{truncate(stderr, 1000)}</ContentCodeBlock>
|
||||
</div>
|
||||
)}
|
||||
{!stdout && !stderr && message && (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
function getWebAccordionData(
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const inp = (input && typeof input === "object" ? input : {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const url =
|
||||
getStringField(inp as Record<string, unknown>, "url", "query") ??
|
||||
"Web content";
|
||||
|
||||
// Try direct string fields first, then MCP content blocks, then raw JSON
|
||||
let content = getStringField(output, "content", "text", "_raw");
|
||||
if (!content) content = extractMcpText(output);
|
||||
if (!content) {
|
||||
// Fallback: render the raw JSON so the accordion isn't empty
|
||||
try {
|
||||
const raw = JSON.stringify(output, null, 2);
|
||||
if (raw !== "{}") content = raw;
|
||||
} catch {
|
||||
/* empty */
|
||||
}
|
||||
}
|
||||
|
||||
const statusCode =
|
||||
typeof output.status_code === "number" ? output.status_code : null;
|
||||
const message = getStringField(output, "message");
|
||||
|
||||
return {
|
||||
title: statusCode
|
||||
? `Response (${statusCode})`
|
||||
: url
|
||||
? "Web fetch"
|
||||
: "Search results",
|
||||
description: truncate(url, 80),
|
||||
content: content ? (
|
||||
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
|
||||
) : message ? (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
) : Object.keys(output).length > 0 ? (
|
||||
<ContentCodeBlock>
|
||||
{truncate(JSON.stringify(output, null, 2), 2000)}
|
||||
</ContentCodeBlock>
|
||||
) : null,
|
||||
};
|
||||
}
|
||||
|
||||
function getFileAccordionData(
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const inp = (input && typeof input === "object" ? input : {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const filePath =
|
||||
getStringField(
|
||||
inp as Record<string, unknown>,
|
||||
"file_path",
|
||||
"path",
|
||||
"pattern",
|
||||
) ?? "File";
|
||||
const content = getStringField(output, "content", "text", "_raw");
|
||||
const message = getStringField(output, "message");
|
||||
// For Glob/list results, try to show file list
|
||||
const files = Array.isArray(output.files)
|
||||
? output.files.filter((f: unknown): f is string => typeof f === "string")
|
||||
: null;
|
||||
|
||||
return {
|
||||
title: message ?? "File output",
|
||||
description: truncate(filePath, 80),
|
||||
content: (
|
||||
<div className="space-y-2">
|
||||
{content && (
|
||||
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
|
||||
)}
|
||||
{files && files.length > 0 && (
|
||||
<ContentCodeBlock>
|
||||
{truncate(files.join("\n"), 2000)}
|
||||
</ContentCodeBlock>
|
||||
)}
|
||||
{!content && !files && message && (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
}

interface TodoItem {
  content: string;
  status: "pending" | "in_progress" | "completed";
  activeForm?: string;
}

function getTodoAccordionData(input: unknown): AccordionData {
  const inp = (input && typeof input === "object" ? input : {}) as Record<
    string,
    unknown
  >;
  const todos: TodoItem[] = Array.isArray(inp.todos)
    ? inp.todos.filter(
        (t: unknown): t is TodoItem =>
          typeof t === "object" &&
          t !== null &&
          typeof (t as TodoItem).content === "string",
      )
    : [];

  const completed = todos.filter((t) => t.status === "completed").length;
  const total = todos.length;

  return {
    title: "Task list",
    description: `${completed}/${total} completed`,
    content: (
      <div className="space-y-1 py-1">
        {todos.map((todo, i) => (
          <div key={i} className="flex items-start gap-2 text-xs">
            <span className="mt-0.5 flex-shrink-0">
              {todo.status === "completed" ? (
                <CheckCircleIcon
                  size={14}
                  weight="fill"
                  className="text-green-500"
                />
              ) : todo.status === "in_progress" ? (
                <CircleDashedIcon
                  size={14}
                  weight="bold"
                  className="text-blue-500"
                />
              ) : (
                <CircleIcon
                  size={14}
                  weight="regular"
                  className="text-neutral-400"
                />
              )}
            </span>
            <span
              className={
                todo.status === "completed"
                  ? "text-muted-foreground line-through"
                  : todo.status === "in_progress"
                    ? "font-medium text-foreground"
                    : "text-muted-foreground"
              }
            >
              {todo.content}
            </span>
          </div>
        ))}
      </div>
    ),
  };
}
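
// Illustrative only: the shape of `part.input` this renderer expects from a
// TodoWrite-style tool call. The field names are taken from the TodoItem
// type above; the task texts are made up for the example.
const exampleTodoInput = {
  todos: [
    { content: "Scaffold the agent graph", status: "completed" },
    { content: "Wire up the webhook block", status: "in_progress" },
    { content: "Add smoke tests", status: "pending" },
  ] satisfies TodoItem[],
};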

function getDefaultAccordionData(
  output: Record<string, unknown>,
): AccordionData {
  const message = getStringField(output, "message");
  const raw = output._raw;
  const mcpText = extractMcpText(output);

  let displayContent: string;
  if (typeof raw === "string") {
    displayContent = raw;
  } else if (mcpText) {
    displayContent = mcpText;
  } else if (message) {
    displayContent = message;
  } else {
    try {
      displayContent = JSON.stringify(output, null, 2);
    } catch {
      displayContent = String(output);
    }
  }

  return {
    title: "Output",
    description: message ?? undefined,
    content: (
      <ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
    ),
  };
}
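
// Illustrative sketch only (`parseOutput` is defined elsewhere in this
// file): the `_raw` convention used above suggests that non-JSON tool
// output is wrapped as { _raw: string } so every renderer receives a
// Record. This is an assumption about its behavior, not the real code.
function parseOutputSketch(output: unknown): Record<string, unknown> | null {
  if (output == null) return null;
  if (typeof output === "object") return output as Record<string, unknown>;
  if (typeof output === "string") {
    try {
      const parsed = JSON.parse(output);
      if (parsed && typeof parsed === "object")
        return parsed as Record<string, unknown>;
    } catch {
      /* not JSON; fall through to raw wrapping */
    }
    return { _raw: output };
  }
  return { _raw: String(output) };
}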

function getAccordionData(
  category: ToolCategory,
  input: unknown,
  output: Record<string, unknown>,
): AccordionData {
  switch (category) {
    case "bash":
      return getBashAccordionData(input, output);
    case "web":
      return getWebAccordionData(input, output);
    case "file-read":
    case "file-write":
    case "file-delete":
    case "file-list":
    case "search":
    case "edit":
      return getFileAccordionData(input, output);
    case "todo":
      return getTodoAccordionData(input);
    default:
      return getDefaultAccordionData(output);
  }
}
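
// Illustrative sketch only (the real `extractMcpText` lives elsewhere in
// this file): MCP tool results conventionally arrive as
// { content: [{ type: "text", text }] }, so a plausible implementation
// joins the text blocks. Assumed shape, not the confirmed helper.
function extractMcpTextSketch(output: Record<string, unknown>): string | null {
  const blocks = Array.isArray(output.content) ? output.content : [];
  const texts = blocks
    .filter(
      (b: unknown): b is { type: string; text: string } =>
        typeof b === "object" &&
        b !== null &&
        (b as { type?: unknown }).type === "text" &&
        typeof (b as { text?: unknown }).text === "string",
    )
    .map((b) => b.text);
  return texts.length > 0 ? texts.join("\n") : null;
}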

/* ------------------------------------------------------------------ */
/* Component                                                           */
/* ------------------------------------------------------------------ */

export function GenericTool({ part }: Props) {
  const toolName = extractToolName(part);
  const category = getToolCategory(toolName);
  const isStreaming =
    part.state === "input-streaming" || part.state === "input-available";
  const isError = part.state === "output-error";
  const text = getAnimationText(part, category);

  const output = parseOutput(part.output);
  const hasOutput =
    part.state === "output-available" &&
    !!output &&
    Object.keys(output).length > 0;
  const hasError = isError && !!output;

  // TodoWrite: always show accordion from input (the todo list lives in input)
  const hasTodoInput =
    category === "todo" &&
    part.input &&
    typeof part.input === "object" &&
    Array.isArray((part.input as Record<string, unknown>).todos);

  const showAccordion = hasOutput || hasError || hasTodoInput;
  const accordionData = showAccordion
    ? getAccordionData(category, part.input, output ?? {})
    : null;

  return (
    <div className="py-2">
      <div className="flex items-center gap-2 text-sm text-muted-foreground">
        <ToolIcon
          category={category}
          isStreaming={isStreaming}
          isError={isError}
        />
        <MorphingTextAnimation
          text={text}
          className={isError ? "text-red-500" : undefined}
        />
      </div>

      {showAccordion && accordionData ? (
        <ToolAccordion
          icon={<AccordionIcon category={category} />}
          title={accordionData.title}
          description={accordionData.description}
          titleClassName={isError ? "text-red-500" : undefined}
        >
          {accordionData.content}
        </ToolAccordion>
      ) : null}
    </div>
  );
}
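
// Illustrative sketch: the stale inline markup removed from the component
// above (an unclosed <ToolIcon> followed by a superseded <GearIcon> with the
// same styling) suggests ToolIcon wraps a Phosphor icon roughly like this,
// with GearIcon as the generic fallback. Assumed shape, not the real
// ToolIcon, which also selects an icon per category.
function ToolIconSketch({
  isStreaming,
  isError,
}: {
  isStreaming: boolean;
  isError: boolean;
}) {
  return (
    <GearIcon
      size={14}
      weight="regular"
      className={
        isError
          ? "text-red-500"
          : isStreaming
            ? "animate-spin text-neutral-500"
            : "text-neutral-400"
      }
    />
  );
}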

@@ -50,14 +50,6 @@ export function useChatSession() {
    );
  }, [sessionQuery.data, sessionId]);

  // Expose active_stream info so the caller can trigger a manual resume
  // after hydration completes (rather than relying on the AI SDK's built-in
  // resume, which fires before hydration).
  const hasActiveStream = useMemo(() => {
    if (sessionQuery.data?.status !== 200) return false;
    return !!sessionQuery.data.data.active_stream;
  }, [sessionQuery.data]);

  const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
    usePostV2CreateSession({
      mutation: {
@@ -110,7 +102,6 @@ export function useChatSession() {
    sessionId,
    setSessionId,
    hydratedMessages,
    hasActiveStream,
    isLoadingSession: sessionQuery.isLoading,
    createSession,
    isCreatingSession,

@@ -29,7 +29,6 @@ export function useCopilotPage() {
    sessionId,
    setSessionId,
    hydratedMessages,
    hasActiveStream,
    isLoadingSession,
    createSession,
    isCreatingSession,
@@ -81,31 +80,14 @@ export function useCopilotPage() {
              },
            };
          },
          // Resume: GET goes to the same URL as POST (the backend uses the
          // method to distinguish). Override the default formula, which
          // would append /{chatId}/stream to the existing path.
          prepareReconnectToStreamRequest: () => ({
            api: `/api/chat/sessions/${sessionId}/stream`,
          }),
        })
      : null,
    [sessionId],
  );

  const {
    messages,
    sendMessage,
    stop,
    status,
    error,
    setMessages,
    resumeStream,
  } = useChat({
    id: sessionId ?? undefined,
    transport: transport ?? undefined,
    // Don't use resume: true — it fires before hydration completes, causing
    // the hydrated messages to overwrite the resumed stream. Instead we
    // call resumeStream() manually after hydration + active_stream detection.
  });
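
// Illustrative sketch only: the hunk above elides the top of the transport
// useMemo, but its tail suggests something along these lines, assuming the
// AI SDK's DefaultChatTransport (imported from "ai" higher up in the file).
// The reconnect override reuses the session stream URL instead of the
// default /{chatId}/stream suffix.
function makeTransportSketch(sessionId: string) {
  return new DefaultChatTransport({
    api: `/api/chat/sessions/${sessionId}/stream`,
    prepareReconnectToStreamRequest: () => ({
      api: `/api/chat/sessions/${sessionId}/stream`,
    }),
  });
}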

  // Abort the stream if the backend doesn't start sending data within 12s.
@@ -126,31 +108,13 @@ export function useCopilotPage() {
    return () => clearTimeout(timer);
  }, [status]);

  // Hydrate messages from the REST session endpoint.
  // Skip hydration while streaming to avoid overwriting the live stream.
  useEffect(() => {
    if (!hydratedMessages || hydratedMessages.length === 0) return;
    if (status === "streaming" || status === "submitted") return;
    setMessages((prev) => {
      if (prev.length >= hydratedMessages.length) return prev;
      return hydratedMessages;
    });
  }, [hydratedMessages, setMessages, status]);
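
// Illustrative sketch only: the hunk above elides the top of the 12s
// watchdog effect, but the surviving comment and cleanup line suggest a
// shape like this. `status` and `stop` come from useChat; the hook name
// is made up for the example.
function useStreamWatchdogSketch(status: string, stop: () => void) {
  useEffect(() => {
    if (status !== "submitted") return;
    // Abort if the backend sends nothing within 12 seconds.
    const timer = setTimeout(() => stop(), 12_000);
    return () => clearTimeout(timer);
  }, [status, stop]);
}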

  // Resume an active stream AFTER hydration completes.
  // The backend returns active_stream info when a task is still running.
  // We wait for hydration so the AI SDK has the conversation history
  // before the resumed stream appends the in-progress assistant message.
  const hasResumedRef = useRef<string | null>(null);
  useEffect(() => {
    if (!hasActiveStream || !sessionId) return;
    if (!hydratedMessages || hydratedMessages.length === 0) return;
    if (status === "streaming" || status === "submitted") return;
    // Only resume once per session to avoid re-triggering after the stream ends
    if (hasResumedRef.current === sessionId) return;
    hasResumedRef.current = sessionId;
    resumeStream();
  }, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);

  // Poll session endpoint when a long-running tool (create_agent, edit_agent)
  // is in progress. When the backend completes, the session data will contain
@@ -233,18 +197,12 @@ export function useCopilotPage() {
    }
  }, [isDeleting]);

  // True while we know the backend has an active stream but haven't
  // reconnected yet. Used to disable the send button and show stop UI.
  const isReconnecting =
    hasActiveStream && status !== "streaming" && status !== "submitted";

  return {
    sessionId,
    messages,
    status,
    error,
    stop,
    isReconnecting,
    isLoadingSession,
    isCreatingSession,
    isUserLoading,