Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-24 03:00:28 -05:00)

Compare commits: fix/copilo... → feat/copil (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 83422d9dc8 | |
| | 4e84be021a | |
| | 81b20ff9d8 | |
@@ -18,7 +18,7 @@ from backend.copilot.completion_handler import (
    process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
from backend.copilot.executor.utils import enqueue_copilot_task
from backend.copilot.model import (
    ChatMessage,
    ChatSession,
@@ -50,7 +50,6 @@ from backend.copilot.tools.models import (
    OperationPendingResponse,
    OperationStartedResponse,
    SetupRequirementsResponse,
    SuggestedGoalResponse,
    UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
@@ -132,14 +131,6 @@ class ListSessionsResponse(BaseModel):
    total: int


class CancelTaskResponse(BaseModel):
    """Response model for the cancel task endpoint."""

    cancelled: bool
    task_id: str | None = None
    reason: str | None = None


class OperationCompleteRequest(BaseModel):
    """Request model for external completion webhook."""

@@ -322,57 +313,6 @@ async def get_session(
    )


@router.post(
    "/sessions/{session_id}/cancel",
    status_code=200,
)
async def cancel_session_task(
    session_id: str,
    user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelTaskResponse:
    """Cancel the active streaming task for a session.

    Publishes a cancel event to the executor via RabbitMQ FANOUT, then
    polls Redis until the task status flips from ``running`` or a timeout
    (5 s) is reached. Returns only after the cancellation is confirmed.
    """
    await _validate_and_get_session(session_id, user_id)

    active_task, _ = await stream_registry.get_active_task_for_session(
        session_id, user_id
    )
    if not active_task:
        return CancelTaskResponse(cancelled=False, reason="no_active_task")

    task_id = active_task.task_id
    await enqueue_cancel_task(task_id)
    logger.info(
        f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
        f"session ...{session_id[-8:]}"
    )

    # Poll until the executor confirms the task is no longer running.
    # Keep max_wait below typical reverse-proxy read timeouts.
    poll_interval = 0.5
    max_wait = 5.0
    waited = 0.0
    while waited < max_wait:
        await asyncio.sleep(poll_interval)
        waited += poll_interval
        task = await stream_registry.get_task(task_id)
        if task is None or task.status != "running":
            logger.info(
                f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
                f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
            )
            return CancelTaskResponse(cancelled=True, task_id=task_id)

    logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
    return CancelTaskResponse(
        cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
    )


@router.post(
    "/sessions/{session_id}/stream",
)
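For reference, a minimal client-side sketch of exercising this endpoint. The base URL, route prefix, and bearer-token auth are assumptions; only the `/sessions/{session_id}/cancel` path and the `CancelTaskResponse` fields come from the diff above.

```python
# Hypothetical client for the cancel endpoint; host, prefix, and auth are assumptions.
import httpx


async def cancel_active_task(base_url: str, session_id: str, token: str) -> None:
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            f"/sessions/{session_id}/cancel",
            headers={"Authorization": f"Bearer {token}"},
            timeout=10.0,  # the endpoint itself may block up to ~5 s while polling Redis
        )
        resp.raise_for_status()
        data = resp.json()
        if not data["cancelled"]:
            print(f"Nothing to cancel: {data.get('reason')}")  # e.g. "no_active_task"
        elif data.get("reason") == "cancel_published_not_confirmed":
            print(f"Cancel published for {data['task_id']}, not yet confirmed")
        else:
            print(f"Task {data['task_id']} confirmed stopped")
```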
@@ -1044,7 +984,6 @@ ToolResponseUnion = (
    | AgentPreviewResponse
    | AgentSavedResponse
    | ClarificationNeededResponse
    | SuggestedGoalResponse
    | BlockListResponse
    | BlockDetailsResponse
    | BlockOutputResponse
@@ -205,20 +205,3 @@ async def enqueue_copilot_task(
        message=entry.model_dump_json(),
        exchange=COPILOT_EXECUTION_EXCHANGE,
    )


async def enqueue_cancel_task(task_id: str) -> None:
    """Publish a cancel request for a running CoPilot task.

    Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
    pods receive the cancellation signal.
    """
    from backend.util.clients import get_async_copilot_queue

    event = CancelCoPilotEvent(task_id=task_id)
    queue_client = await get_async_copilot_queue()
    await queue_client.publish_message(
        routing_key="",  # FANOUT ignores routing key
        message=event.model_dump_json(),
        exchange=COPILOT_CANCEL_EXCHANGE,
    )
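The endpoint above publishes the cancel event and then confirms by polling shared state. Below is a self-contained sketch of that confirm-by-polling pattern, with illustrative names standing in for the real stream_registry and Redis plumbing.

```python
# Illustrative only: one side flips a shared status (what the executor pod would
# write to Redis after receiving the FANOUT cancel event), the other polls until
# the status changes or a timeout is hit, mirroring cancel_session_task above.
import asyncio

status: dict[str, str] = {"task-1": "running"}


async def executor_side(task_id: str) -> None:
    await asyncio.sleep(1.0)        # pretend the cancel event arrives over RabbitMQ
    status[task_id] = "cancelled"   # stand-in for the executor updating Redis


async def wait_for_confirmation(task_id: str, max_wait: float = 5.0) -> bool:
    poll_interval, waited = 0.5, 0.0
    while waited < max_wait:
        await asyncio.sleep(poll_interval)
        waited += poll_interval
        if status.get(task_id) != "running":
            return True             # confirmed stopped
    return False                    # published but not confirmed


async def main() -> None:
    asyncio.create_task(executor_side("task-1"))
    print(await wait_for_confirmation("task-1"))  # True after ~1.5 s


asyncio.run(main())
```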
@@ -1,272 +0,0 @@
|
||||
"""Tests for parallel tool call execution in CoPilot.
|
||||
|
||||
These tests mock _yield_tool_call to avoid importing the full copilot stack
|
||||
which requires Prisma, DB connections, etc.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_run_concurrently():
|
||||
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
|
||||
# Import here to allow module-level mocking if needed
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
n_tools = 3
|
||||
delay_per_tool = 0.2
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"tool_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(n_tools)
|
||||
]
|
||||
|
||||
# Minimal session mock
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
original_yield = None
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
input={},
|
||||
)
|
||||
await asyncio.sleep(delay_per_tool)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
output="{}",
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
original_yield = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
start = time.monotonic()
|
||||
events = []
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
elapsed = time.monotonic() - start
|
||||
finally:
|
||||
svc._yield_tool_call = original_yield
|
||||
|
||||
assert len(events) == n_tools * 2
|
||||
# Parallel: should take ~delay, not ~n*delay
|
||||
assert elapsed < delay_per_tool * (
|
||||
n_tools - 0.5
|
||||
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call_works():
|
||||
"""Single tool call should work identically."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": "call_0",
|
||||
"type": "function",
|
||||
"function": {"name": "t", "arguments": "{}"},
|
||||
}
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
|
||||
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = [
|
||||
e
|
||||
async for e in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
)
|
||||
]
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(events) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_propagates():
|
||||
"""Retryable errors should be raised after all tools finish."""
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
if idx == 1:
|
||||
raise KeyError("bad")
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = []
|
||||
with pytest.raises(KeyError):
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
# First tool's events should still be yielded
|
||||
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_lock_shared():
|
||||
"""All parallel tools should receive the same lock instance."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
observed_locks = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
observed_locks.append(lock)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
async for _ in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
pass
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(observed_locks) == 3
|
||||
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
|
||||
assert isinstance(observed_locks[0], asyncio.Lock)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_cleans_up():
|
||||
"""Generator close should cancel in-flight tasks."""
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
started.set()
|
||||
await asyncio.sleep(10) # simulate long-running
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
|
||||
await gen.__anext__() # get first event
|
||||
await started.wait()
|
||||
await gen.aclose() # close generator
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
# If we get here without hanging, cleanup worked
|
||||
@@ -16,6 +16,7 @@ from .tool_adapter import (
    DANGEROUS_PATTERNS,
    MCP_TOOL_PREFIX,
    WORKSPACE_SCOPED_TOOLS,
    get_sandbox_manager,
    stash_pending_tool_output,
)

@@ -97,8 +98,10 @@ def _validate_tool_access(
            "Use the CoPilot-specific MCP tools instead."
        )

    # Workspace-scoped tools: allowed only within the SDK workspace directory
    if tool_name in WORKSPACE_SCOPED_TOOLS:
    # Workspace-scoped tools: allowed only within the SDK workspace directory.
    # When e2b is enabled, these SDK built-in tools are disabled (replaced by
    # MCP e2b file tools), so skip workspace path validation.
    if tool_name in WORKSPACE_SCOPED_TOOLS and get_sandbox_manager() is None:
        return _validate_workspace_path(tool_name, tool_input, sdk_cwd)

    # Check for dangerous patterns in tool input
@@ -58,6 +58,9 @@ from .transcript import (
logger = logging.getLogger(__name__)
config = ChatConfig()

# SDK built-in file tools to disable when e2b is active (replaced by MCP tools)
_E2B_DISALLOWED_SDK_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]

# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()

@@ -98,6 +101,23 @@ _SDK_TOOL_SUPPLEMENT = """
is delivered to the user via a background stream.
"""

_SDK_TOOL_SUPPLEMENT_E2B = """

## Tool notes

- The SDK built-in Bash, Read, Write, Edit, Glob, and Grep tools are NOT available.
  Use the MCP tools instead: `bash_exec`, `read_file`, `write_file`, `edit_file`,
  `glob_files`, `grep_files`.
- **All tools share a single sandbox**: The sandbox is a microVM with a shared
  filesystem at /home/user/. Files created by any tool are accessible to all others.
  Network access IS available (pip install, curl, etc.).
- **Persistent storage**: Use `save_to_workspace` to persist sandbox files to cloud
  storage, and `load_from_workspace` to bring workspace files into the sandbox.
- Long-running tools (create_agent, edit_agent, etc.) are handled
  asynchronously. You will receive an immediate response; the actual result
  is delivered to the user via a background stream.
"""


def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
    """Build a callback that delegates long-running tools to the non-SDK infrastructure.
@@ -453,12 +473,33 @@ async def stream_chat_completion_sdk(
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    # Check if e2b sandbox is enabled for this user
    sandbox_mgr = None
    use_e2b = False
    try:
        from backend.util.feature_flag import Flag
        from backend.util.feature_flag import is_feature_enabled as _is_flag_enabled
        from backend.util.settings import Config as AppConfig

        app_config = AppConfig()
        use_e2b = await _is_flag_enabled(
            Flag.COPILOT_E2B,
            user_id or "anonymous",
            default=app_config.copilot_use_e2b,
        )
        if use_e2b:
            from backend.copilot.tools.e2b_sandbox import CoPilotSandboxManager

            sandbox_mgr = CoPilotSandboxManager()
    except Exception as e:
        logger.warning(f"[SDK] Failed to initialize e2b sandbox: {e}")

    # Build system prompt (reuses non-SDK path with Langfuse support)
    has_history = len(session.messages) > 1
    system_prompt, _ = await _build_system_prompt(
        user_id, has_conversation_history=has_history
    )
    system_prompt += _SDK_TOOL_SUPPLEMENT
    system_prompt += _SDK_TOOL_SUPPLEMENT_E2B if use_e2b else _SDK_TOOL_SUPPLEMENT
    message_id = str(uuid.uuid4())
    task_id = str(uuid.uuid4())
@@ -480,6 +521,7 @@ async def stream_chat_completion_sdk(
        user_id,
        session,
        long_running_callback=_build_long_running_callback(user_id),
        sandbox_manager=sandbox_mgr,
    )
    try:
        from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient

@@ -531,11 +573,21 @@ async def stream_chat_completion_sdk(
            f"msg_count={transcript_msg_count})"
        )

        # When e2b is active, disable SDK built-in file tools
        # (replaced by MCP e2b tools) and remove them from allowed list
        effective_disallowed = list(SDK_DISALLOWED_TOOLS)
        effective_allowed = list(COPILOT_TOOL_NAMES)
        if use_e2b:
            effective_disallowed.extend(_E2B_DISALLOWED_SDK_TOOLS)
            effective_allowed = [
                t for t in effective_allowed if t not in _E2B_DISALLOWED_SDK_TOOLS
            ]

        sdk_options_kwargs: dict[str, Any] = {
            "system_prompt": system_prompt,
            "mcp_servers": {"copilot": mcp_server},
            "allowed_tools": COPILOT_TOOL_NAMES,
            "disallowed_tools": SDK_DISALLOWED_TOOLS,
            "allowed_tools": effective_allowed,
            "disallowed_tools": effective_disallowed,
            "hooks": security_hooks,
            "cwd": sdk_cwd,
            "max_buffer_size": config.claude_agent_max_buffer_size,

@@ -749,6 +801,11 @@ async def stream_chat_completion_sdk(
        )
        yield StreamFinish()
    finally:
        if sandbox_mgr:
            try:
                await sandbox_mgr.dispose_all()
            except Exception as e:
                logger.warning(f"[SDK] Failed to dispose e2b sandboxes: {e}")
        if sdk_cwd:
            _cleanup_sdk_tool_results(sdk_cwd)
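The diff references a `CoPilotSandboxManager` with `get_or_create` and `dispose_all`, but its implementation is not shown here. Below is a minimal sketch of what such a manager could look like on top of e2b's `AsyncSandbox`; the per-session caching policy is an assumption, and the e2b method names (`create`, `kill`) are taken from the public SDK as best understood, not from this repo.

```python
# A minimal sketch (not the project's implementation) of a per-session sandbox
# manager. Caching policy and the exact e2b API surface are assumptions.
from e2b import AsyncSandbox


class SimpleSandboxManager:
    def __init__(self) -> None:
        self._sandboxes: dict[str, AsyncSandbox] = {}

    async def get_or_create(self, session_id: str, user_id: str) -> AsyncSandbox:
        key = f"{user_id}:{session_id}"
        if key not in self._sandboxes:
            # One microVM per session; every tool in the session shares its filesystem.
            self._sandboxes[key] = await AsyncSandbox.create()
        return self._sandboxes[key]

    async def dispose_all(self) -> None:
        # Mirrors the `finally: await sandbox_mgr.dispose_all()` cleanup above.
        for sandbox in self._sandboxes.values():
            try:
                await sandbox.kill()
            except Exception:
                pass  # best-effort cleanup, matching the log-and-continue pattern
        self._sandboxes.clear()
```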
@@ -42,7 +42,8 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
    "pending_tool_outputs", default=None  # type: ignore[arg-type]
    "pending_tool_outputs",
    default=None,  # type: ignore[arg-type]
)

# Callback type for delegating long-running tools to the non-SDK infrastructure.

@@ -56,11 +57,15 @@ _long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
    "long_running_callback", default=None
)

# ContextVar for the e2b sandbox manager (set when e2b is enabled).
_sandbox_manager: ContextVar[Any | None] = ContextVar("sandbox_manager", default=None)


def set_execution_context(
    user_id: str | None,
    session: ChatSession,
    long_running_callback: LongRunningCallback | None = None,
    sandbox_manager: Any | None = None,
) -> None:
    """Set the execution context for tool calls.

@@ -72,11 +77,13 @@ def set_execution_context(
        session: Current chat session.
        long_running_callback: Optional callback to delegate long-running tools
            to the non-SDK background infrastructure (stream_registry + Redis).
        sandbox_manager: Optional CoPilotSandboxManager for e2b sandbox access.
    """
    _current_user_id.set(user_id)
    _current_session.set(session)
    _pending_tool_outputs.set({})
    _long_running_callback.set(long_running_callback)
    _sandbox_manager.set(sandbox_manager)


def get_execution_context() -> tuple[str | None, ChatSession | None]:

@@ -87,6 +94,11 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:
    )


def get_sandbox_manager() -> Any | None:
    """Get the current e2b sandbox manager from execution context."""
    return _sandbox_manager.get(None)


def pop_pending_tool_output(tool_name: str) -> str | None:
    """Pop and return the oldest stashed output for *tool_name*.
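`set_execution_context` stores the sandbox manager in a `ContextVar`, so concurrent requests each see only the manager set in their own context. A self-contained illustration of that isolation, with stand-in names:

```python
# Stand-in illustration of why a ContextVar is used for the sandbox manager:
# each asyncio task gets its own context copy, so sets do not leak across requests.
import asyncio
from contextvars import ContextVar
from typing import Any

_sandbox_manager: ContextVar[Any | None] = ContextVar("sandbox_manager", default=None)


def get_sandbox_manager() -> Any | None:
    return _sandbox_manager.get()


async def handle_request(name: str, manager: Any | None) -> None:
    _sandbox_manager.set(manager)   # like set_execution_context(..., sandbox_manager=...)
    await asyncio.sleep(0)          # yield to the other "request"
    print(name, "sees", get_sandbox_manager())


async def main() -> None:
    # Two concurrent requests: one with e2b enabled, one without.
    await asyncio.gather(
        handle_request("request-a", "e2b-manager"),
        handle_request("request-b", None),
    )


asyncio.run(main())
```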
@@ -118,8 +118,6 @@ Adapt flexibly to the conversation context. Not every interaction requires all s
- Find reusable components with `find_block`
- Create custom solutions with `create_agent` if nothing suitable exists
- Modify existing library agents with `edit_agent`
- **When `create_agent` returns `suggested_goal`**: Present the suggestion to the user and ask "Would you like me to proceed with this refined goal?" If they accept, call `create_agent` again with the suggested goal.
- **When `create_agent` returns `clarifying_questions`**: After the user answers, call `create_agent` again with the original description AND the answers in the `context` parameter.

5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.

@@ -166,11 +164,6 @@ Adapt flexibly to the conversation context. Not every interaction requires all s
- Use `add_understanding` to capture valuable business context
- When tool calls fail, try alternative approaches

**Handle Feedback Loops:**
- When a tool returns a suggested alternative (like a refined goal), present it clearly and ask the user for confirmation before proceeding
- When clarifying questions are answered, immediately re-call the tool with the accumulated context
- Don't ask redundant questions if the user has already provided context in the conversation

## CRITICAL REMINDER

You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
@@ -1232,10 +1225,23 @@ async def _stream_chat_chunks(
            },
        )

        # Execute all accumulated tool calls in parallel
        # Events are yielded as they arrive from each concurrent tool
        async for event in _execute_tool_calls_parallel(tool_calls, session):
            yield event
        # Yield all accumulated tool calls after the stream is complete
        # This ensures all tool call arguments have been fully received
        for idx, tool_call in enumerate(tool_calls):
            try:
                async for tc in _yield_tool_call(tool_calls, idx, session):
                    yield tc
            except (orjson.JSONDecodeError, KeyError, TypeError) as e:
                logger.error(
                    f"Failed to parse tool call {idx}: {e}",
                    exc_info=True,
                    extra={"tool_call": tool_call},
                )
                yield StreamError(
                    errorText=f"Invalid tool call arguments for tool {tool_call.get('function', {}).get('name', 'unknown')}: {e}",
                )
                # Re-raise to trigger retry logic in the parent function
                raise

        total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
        logger.info(

@@ -1313,91 +1319,10 @@ async def _stream_chat_chunks(
        return
async def _with_optional_lock(
|
||||
lock: asyncio.Lock | None,
|
||||
coro_fn: Any,
|
||||
) -> Any:
|
||||
"""Run *coro_fn()* under *lock* when provided, otherwise run directly."""
|
||||
if lock:
|
||||
async with lock:
|
||||
return await coro_fn()
|
||||
return await coro_fn()
|
||||
|
||||
|
||||
async def _execute_tool_calls_parallel(
|
||||
tool_calls: list[dict[str, Any]],
|
||||
session: ChatSession,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Execute all tool calls concurrently, yielding stream events as they arrive.
|
||||
|
||||
Each tool runs as an ``asyncio.Task``, pushing events into a shared queue.
|
||||
A ``session_lock`` serialises session-state mutations (long-running tool
|
||||
bookkeeping, ``run_agent`` counters).
|
||||
"""
|
||||
queue: asyncio.Queue[StreamBaseResponse | None] = asyncio.Queue()
|
||||
session_lock = asyncio.Lock()
|
||||
n_tools = len(tool_calls)
|
||||
retryable_errors: list[Exception] = []
|
||||
|
||||
async def _run_tool(idx: int) -> None:
|
||||
tool_name = tool_calls[idx].get("function", {}).get("name", "unknown")
|
||||
tool_call_id = tool_calls[idx].get("id", f"unknown_{idx}")
|
||||
try:
|
||||
async for event in _yield_tool_call(tool_calls, idx, session, session_lock):
|
||||
await queue.put(event)
|
||||
except (orjson.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.error(
|
||||
f"Failed to parse tool call {idx} ({tool_name}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
retryable_errors.append(e)
|
||||
except Exception as e:
|
||||
# Infrastructure / setup errors — emit an error output so the
|
||||
# client always sees a terminal event and doesn't hang.
|
||||
logger.error(f"Tool call {idx} ({tool_name}) failed: {e}", exc_info=True)
|
||||
await queue.put(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=ErrorResponse(
|
||||
message=f"Tool execution failed: {e!s}",
|
||||
error=type(e).__name__,
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await queue.put(None) # sentinel
|
||||
|
||||
tasks = [asyncio.create_task(_run_tool(idx)) for idx in range(n_tools)]
|
||||
try:
|
||||
finished = 0
|
||||
while finished < n_tools:
|
||||
event = await queue.get()
|
||||
if event is None:
|
||||
finished += 1
|
||||
else:
|
||||
yield event
|
||||
if retryable_errors:
|
||||
if len(retryable_errors) > 1:
|
||||
logger.warning(
|
||||
f"{len(retryable_errors)} tool calls had retryable errors; "
|
||||
f"re-raising first to trigger retry"
|
||||
)
|
||||
raise retryable_errors[0]
|
||||
finally:
|
||||
for t in tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
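`_execute_tool_calls_parallel` above fans events from several worker tasks into one stream via a shared queue, using a `None` sentinel per worker to know when all have finished. A stripped-down, runnable version of that pattern with illustrative names:

```python
# Queue + sentinel fan-in: each worker pushes events, then a None sentinel;
# the consumer stops after one sentinel per worker and cancels stragglers on exit.
import asyncio
from collections.abc import AsyncGenerator


async def fan_in(n_workers: int) -> AsyncGenerator[str, None]:
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def worker(i: int) -> None:
        try:
            await asyncio.sleep(0.1 * i)      # simulate tool latency
            await queue.put(f"result-{i}")
        finally:
            await queue.put(None)             # sentinel: this worker is done

    tasks = [asyncio.create_task(worker(i)) for i in range(n_workers)]
    try:
        finished = 0
        while finished < n_workers:
            item = await queue.get()
            if item is None:
                finished += 1
            else:
                yield item
    finally:
        for t in tasks:                       # cancel in-flight tasks on early close
            if not t.done():
                t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def main() -> None:
    async for event in fan_in(3):
        print(event)


asyncio.run(main())
```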
async def _yield_tool_call(
|
||||
tool_calls: list[dict[str, Any]],
|
||||
yield_idx: int,
|
||||
session: ChatSession,
|
||||
session_lock: asyncio.Lock | None = None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""
|
||||
Yield a tool call and its execution result.
|
||||
@@ -1495,7 +1420,8 @@ async def _yield_tool_call(
|
||||
"check back in a few minutes."
|
||||
)
|
||||
|
||||
# Track appended message for rollback on failure
|
||||
# Track appended messages for rollback on failure
|
||||
assistant_message: ChatMessage | None = None
|
||||
pending_message: ChatMessage | None = None
|
||||
|
||||
# Wrap session save and task creation in try-except to release lock on failure
|
||||
@@ -1510,24 +1436,22 @@ async def _yield_tool_call(
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Attach tool_call and save pending result — lock serialises
|
||||
# concurrent session mutations during parallel execution.
|
||||
async def _save_pending() -> None:
|
||||
nonlocal pending_message
|
||||
session.add_tool_call_to_current_turn(tool_calls[yield_idx])
|
||||
pending_message = ChatMessage(
|
||||
role="tool",
|
||||
content=OperationPendingResponse(
|
||||
message=pending_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump_json(),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
# Attach the tool_call to the current turn's assistant message
|
||||
# (or create one if this is a tool-only response with no text).
|
||||
session.add_tool_call_to_current_turn(tool_calls[yield_idx])
|
||||
|
||||
await _with_optional_lock(session_lock, _save_pending)
|
||||
# Then save pending tool result
|
||||
pending_message = ChatMessage(
|
||||
role="tool",
|
||||
content=OperationPendingResponse(
|
||||
message=pending_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump_json(),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
logger.info(
|
||||
f"Saved pending operation {operation_id} (task_id={task_id}) "
|
||||
f"for tool {tool_name} in session {session.session_id}"
|
||||
@@ -1551,13 +1475,19 @@ async def _yield_tool_call(
|
||||
# Associate the asyncio task with the stream registry task
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
except Exception as e:
|
||||
# Roll back appended messages — use identity-based removal so
|
||||
# it works even when other parallel tools have appended after us.
|
||||
async def _rollback() -> None:
|
||||
if pending_message and pending_message in session.messages:
|
||||
session.messages.remove(pending_message)
|
||||
|
||||
await _with_optional_lock(session_lock, _rollback)
|
||||
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||
if (
|
||||
pending_message
|
||||
and session.messages
|
||||
and session.messages[-1] == pending_message
|
||||
):
|
||||
session.messages.pop()
|
||||
if (
|
||||
assistant_message
|
||||
and session.messages
|
||||
and session.messages[-1] == assistant_message
|
||||
):
|
||||
session.messages.pop()
|
||||
|
||||
# Release the Redis lock since the background task won't be spawned
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
|
||||
@@ -13,6 +13,15 @@ from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .e2b_file_tools import (
    E2BEditTool,
    E2BGlobTool,
    E2BGrepTool,
    E2BReadTool,
    E2BWriteTool,
    LoadFromWorkspaceTool,
    SaveToWorkspaceTool,
)
from .edit_agent import EditAgentTool
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .find_agent import FindAgentTool

@@ -63,6 +72,14 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
    "read_workspace_file": ReadWorkspaceFileTool(),
    "write_workspace_file": WriteWorkspaceFileTool(),
    "delete_workspace_file": DeleteWorkspaceFileTool(),
    # E2B sandbox file tools (active when COPILOT_E2B feature flag is enabled)
    "read_file": E2BReadTool(),
    "write_file": E2BWriteTool(),
    "edit_file": E2BEditTool(),
    "glob_files": E2BGlobTool(),
    "grep_files": E2BGrepTool(),
    "save_to_workspace": SaveToWorkspaceTool(),
    "load_from_workspace": LoadFromWorkspaceTool(),
}

# Export individual tool instances for backwards compatibility
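`TOOL_REGISTRY` maps tool names to instances, so dispatching a tool call is a dictionary lookup. A hedged sketch of that lookup; the `_execute` keyword signature mirrors the tests elsewhere in this diff and is an assumption rather than the public API:

```python
# Illustrative dispatch over TOOL_REGISTRY; not the project's actual call path.
from typing import Any

from backend.copilot.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY


async def dispatch(
    tool_name: str, user_id: str | None, session: ChatSession, **arguments: Any
) -> Any:
    tool = TOOL_REGISTRY.get(tool_name)
    if tool is None:
        raise KeyError(f"Unknown tool: {tool_name}")
    # E2B file tools resolve their sandbox from the execution context at call time.
    return await tool._execute(user_id=user_id, session=session, **arguments)
```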
@@ -1,14 +1,15 @@
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
"""Bash execution tool — run shell commands in a sandbox.

Supports two backends:
- **e2b** (preferred): VM-level isolation with network access, enabled via
  the COPILOT_E2B feature flag.
- **bubblewrap** (fallback): kernel-level isolation, no network, Linux-only.

Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
read-only, writable workspace only, clean env, no network.

Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
available (e.g. macOS development).
"""

import logging
import shlex
from typing import Any

from backend.copilot.model import ChatSession

@@ -19,6 +20,8 @@ from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed

logger = logging.getLogger(__name__)

_SANDBOX_HOME = "/home/user"


class BashExecTool(BaseTool):
    """Execute Bash commands in a bubblewrap sandbox."""
@@ -29,6 +32,18 @@ class BashExecTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
if _is_e2b_available():
|
||||
return (
|
||||
"Execute a Bash command or script in an e2b sandbox (microVM). "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The sandbox shares the same filesystem as the read_file/write_file "
|
||||
"tools — files created by any tool are accessible to all others. "
|
||||
"Network access IS available (pip install, curl, etc.). "
|
||||
"Working directory is /home/user/. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr."
|
||||
)
|
||||
if not has_full_sandbox():
|
||||
return (
|
||||
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
||||
@@ -85,13 +100,6 @@ class BashExecTool(BaseTool):
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
command: str = (kwargs.get("command") or "").strip()
|
||||
timeout: int = kwargs.get("timeout", 30)
|
||||
|
||||
@@ -102,6 +110,20 @@ class BashExecTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# --- E2B path ---
|
||||
if _is_e2b_available():
|
||||
return await self._execute_e2b(
|
||||
command, timeout, session, user_id, session_id
|
||||
)
|
||||
|
||||
# --- Bubblewrap fallback ---
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
workspace = get_workspace_dir(session_id or "default")
|
||||
|
||||
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
||||
@@ -122,3 +144,72 @@ class BashExecTool(BaseTool):
|
||||
timed_out=timed_out,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _execute_e2b(
|
||||
self,
|
||||
command: str,
|
||||
timeout: int,
|
||||
session: ChatSession,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute command in e2b sandbox."""
|
||||
try:
|
||||
from backend.copilot.sdk.tool_adapter import get_sandbox_manager
|
||||
|
||||
manager = get_sandbox_manager()
|
||||
if manager is None:
|
||||
return ErrorResponse(
|
||||
message="E2B sandbox manager not available.",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
sandbox = await manager.get_or_create(
|
||||
session_id or "default", user_id or "anonymous"
|
||||
)
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=_SANDBOX_HOME,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return BashExecResponse(
|
||||
message=f"Command executed (exit {result.exit_code})",
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
if "timeout" in error_str.lower():
|
||||
return BashExecResponse(
|
||||
message="Execution timed out",
|
||||
stdout="",
|
||||
stderr=f"Execution timed out after {timeout}s",
|
||||
exit_code=-1,
|
||||
timed_out=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"E2B execution failed: {e}",
|
||||
error=error_str,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Module-level helpers (placed after classes that call them)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _is_e2b_available() -> bool:
|
||||
"""Check if e2b sandbox is available via execution context."""
|
||||
try:
|
||||
from backend.copilot.sdk.tool_adapter import get_sandbox_manager
|
||||
|
||||
return get_sandbox_manager() is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
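`_execute_e2b` above wraps the user command as `bash -c {shlex.quote(command)}`, so the whole script reaches the sandbox shell as one untouched argument even when it contains quotes, pipes, or multiple lines. A quick illustration:

```python
# Why the command is passed through shlex.quote before `bash -c`:
# the entire script must arrive as a single, unmangled argv element.
import shlex

script = 'grep -rn "TODO" . | wc -l'
print(f"bash -c {shlex.quote(script)}")
# bash -c 'grep -rn "TODO" . | wc -l'
```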
@@ -22,7 +22,6 @@ from .models import (
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
SuggestedGoalResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
@@ -187,28 +186,26 @@ class CreateAgentTool(BaseTool):
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return SuggestedGoalResponse(
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. {reason}"
|
||||
f"This goal cannot be accomplished with the available blocks. "
|
||||
f"{reason} "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
suggested_goal=suggested,
|
||||
reason=reason,
|
||||
original_goal=description,
|
||||
goal_type="unachievable",
|
||||
error="unachievable_goal",
|
||||
details={"suggested_goal": suggested, "reason": reason},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get(
|
||||
"reason", "The goal needs more specific details"
|
||||
)
|
||||
return SuggestedGoalResponse(
|
||||
message="The goal is too vague to create a specific workflow.",
|
||||
suggested_goal=suggested,
|
||||
reason=reason,
|
||||
original_goal=description,
|
||||
goal_type="vague",
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The goal is too vague to create a specific workflow. "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="vague_goal",
|
||||
details={"suggested_goal": suggested},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
"""Tests for CreateAgentTool response types."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.create_agent import CreateAgentTool
|
||||
from backend.copilot.tools.models import (
|
||||
ClarificationNeededResponse,
|
||||
ErrorResponse,
|
||||
SuggestedGoalResponse,
|
||||
)
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-create-agent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return CreateAgentTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_description_returns_error(tool, session):
|
||||
"""Missing description returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session, description="")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "Missing description parameter"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vague_goal_returns_suggested_goal_response(tool, session):
|
||||
"""vague_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
vague_result = {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": "Monitor Twitter mentions for a specific keyword and send a daily digest email",
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=vague_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="monitor social media",
|
||||
)
|
||||
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "vague"
|
||||
assert result.suggested_goal == vague_result["suggested_goal"]
|
||||
assert result.original_goal == "monitor social media"
|
||||
assert result.reason == "The goal needs more specific details"
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
|
||||
"""unachievable_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
unachievable_result = {
|
||||
"type": "unachievable_goal",
|
||||
"suggested_goal": "Summarize the latest news articles on a topic and send them by email",
|
||||
"reason": "There are no blocks for mind-reading.",
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=unachievable_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="read my mind",
|
||||
)
|
||||
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "unachievable"
|
||||
assert result.suggested_goal == unachievable_result["suggested_goal"]
|
||||
assert result.original_goal == "read my mind"
|
||||
assert result.reason == unachievable_result["reason"]
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clarifying_questions_returns_clarification_needed_response(
|
||||
tool, session
|
||||
):
|
||||
"""clarifying_questions decomposition result returns ClarificationNeededResponse."""
|
||||
clarifying_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{
|
||||
"question": "What platform should be monitored?",
|
||||
"keyword": "platform",
|
||||
"example": "Twitter, Reddit",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=clarifying_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="monitor social media and alert me",
|
||||
)
|
||||
|
||||
assert isinstance(result, ClarificationNeededResponse)
|
||||
assert len(result.questions) == 1
|
||||
assert result.questions[0].keyword == "platform"
|
||||
autogpt_platform/backend/backend/copilot/tools/e2b_file_tools.py (new file, 703 lines)
@@ -0,0 +1,703 @@
|
||||
"""E2B file tools — MCP tools that proxy filesystem operations to the e2b sandbox.
|
||||
|
||||
These replace the SDK built-in Read/Write/Edit/Glob/Grep tools when e2b is
|
||||
enabled, ensuring all file operations go through the sandbox VM.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import posixpath
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SANDBOX_HOME = "/home/user"
|
||||
|
||||
|
||||
class E2BReadTool(BaseTool):
|
||||
"""Read a file from the e2b sandbox filesystem."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a file from the sandbox filesystem. "
|
||||
"The sandbox is the shared working environment — files created by "
|
||||
"any tool (bash_exec, write_file, etc.) are accessible here. "
|
||||
"Returns the file content as text. "
|
||||
"Use offset and limit for large files."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Path to the file to read (relative to /home/user/ "
|
||||
"or absolute within /home/user/)."
|
||||
),
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Line number to start reading from (0-indexed). Default: 0"
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read. Default: 2000",
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
path = kwargs.get("path", "")
|
||||
offset = kwargs.get("offset", 0)
|
||||
limit = kwargs.get("limit", 2000)
|
||||
|
||||
sandbox = await _get_sandbox(session)
|
||||
if sandbox is None:
|
||||
return _sandbox_unavailable(session)
|
||||
|
||||
resolved = _resolve_path(path)
|
||||
if resolved is None:
|
||||
return _path_error(path, session)
|
||||
|
||||
try:
|
||||
content = await sandbox.files.read(resolved)
|
||||
lines = content.splitlines(keepends=True)
|
||||
selected = lines[offset : offset + limit]
|
||||
text = "".join(selected)
|
||||
return BashExecResponse(
|
||||
message=f"Read {len(selected)} lines from {resolved}",
|
||||
stdout=text,
|
||||
stderr="",
|
||||
exit_code=0,
|
||||
timed_out=False,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read {resolved}: {e}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
|
||||
class E2BWriteTool(BaseTool):
|
||||
"""Write a file to the e2b sandbox filesystem."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write or create a file in the sandbox filesystem. "
|
||||
"This is the shared working environment — files are accessible "
|
||||
"to bash_exec and other tools. "
|
||||
"Creates parent directories automatically."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Path for the file (relative to /home/user/ "
|
||||
"or absolute within /home/user/)."
|
||||
),
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file.",
|
||||
},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
path = kwargs.get("path", "")
|
||||
content = kwargs.get("content", "")
|
||||
|
||||
sandbox = await _get_sandbox(session)
|
||||
if sandbox is None:
|
||||
return _sandbox_unavailable(session)
|
||||
|
||||
resolved = _resolve_path(path)
|
||||
if resolved is None:
|
||||
return _path_error(path, session)
|
||||
|
||||
try:
|
||||
# Ensure parent directory exists
|
||||
parent = posixpath.dirname(resolved)
|
||||
if parent and parent != _SANDBOX_HOME:
|
||||
await sandbox.commands.run(f"mkdir -p {parent}", timeout=5)
|
||||
await sandbox.files.write(resolved, content)
|
||||
return BashExecResponse(
|
||||
message=f"Wrote {len(content)} bytes to {resolved}",
|
||||
stdout=f"Successfully wrote to {resolved}",
|
||||
stderr="",
|
||||
exit_code=0,
|
||||
timed_out=False,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write {resolved}: {e}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
|
||||
class E2BEditTool(BaseTool):
|
||||
"""Edit a file in the e2b sandbox using search/replace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a file in the sandbox by replacing exact text. "
|
||||
"Provide old_text (the exact text to find) and new_text "
|
||||
"(what to replace it with). The old_text must match exactly."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Path to the file (relative to /home/user/ "
|
||||
"or absolute within /home/user/)."
|
||||
),
|
||||
},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "Exact text to find in the file.",
|
||||
},
|
||||
"new_text": {
|
||||
"type": "string",
|
||||
"description": "Text to replace old_text with.",
|
||||
},
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
path = kwargs.get("path", "")
|
||||
old_text = kwargs.get("old_text", "")
|
||||
new_text = kwargs.get("new_text", "")
|
||||
|
||||
sandbox = await _get_sandbox(session)
|
||||
if sandbox is None:
|
||||
return _sandbox_unavailable(session)
|
||||
|
||||
resolved = _resolve_path(path)
|
||||
if resolved is None:
|
||||
return _path_error(path, session)
|
||||
|
||||
try:
|
||||
content = await sandbox.files.read(resolved)
|
||||
occurrences = content.count(old_text)
|
||||
if occurrences == 0:
|
||||
return ErrorResponse(
|
||||
message=f"old_text not found in {resolved}",
|
||||
error="text_not_found",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
if occurrences > 1:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"old_text found {occurrences} times in {resolved}. "
|
||||
"Please provide more context to make the match unique."
|
||||
),
|
||||
error="ambiguous_match",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
new_content = content.replace(old_text, new_text, 1)
|
||||
await sandbox.files.write(resolved, new_content)
|
||||
return BashExecResponse(
|
||||
message=f"Edited {resolved}",
|
||||
stdout=f"Successfully edited {resolved}",
|
||||
stderr="",
|
||||
exit_code=0,
|
||||
timed_out=False,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to edit {resolved}: {e}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
|
||||
class E2BGlobTool(BaseTool):
|
||||
"""List files matching a pattern in the e2b sandbox."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "glob_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the sandbox matching a glob pattern. "
|
||||
"Uses find under the hood. Default directory is /home/user/."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Glob pattern to match (e.g., '*.py', '**/*.json')."
|
||||
),
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": ("Directory to search in (default: /home/user/)."),
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
pattern = kwargs.get("pattern", "*")
|
||||
path = kwargs.get("path", _SANDBOX_HOME)
|
||||
|
||||
sandbox = await _get_sandbox(session)
|
||||
if sandbox is None:
|
||||
return _sandbox_unavailable(session)
|
||||
|
||||
resolved = _resolve_path(path)
|
||||
if resolved is None:
|
||||
return _path_error(path, session)
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"find {resolved} -name {shlex.quote(pattern)} -type f 2>/dev/null",
|
||||
timeout=15,
|
||||
)
|
||||
return BashExecResponse(
|
||||
message="Glob results",
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to glob: {e}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
|
||||
class E2BGrepTool(BaseTool):
|
||||
"""Search file contents in the e2b sandbox."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "grep_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for a pattern in files within the sandbox. "
|
||||
"Uses grep -rn under the hood. Returns matching lines with "
|
||||
"file paths and line numbers."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Search pattern (regex supported).",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": ("Directory to search in (default: /home/user/)."),
|
||||
},
|
||||
"include": {
|
||||
"type": "string",
|
||||
"description": "File glob to include (e.g., '*.py').",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
pattern = kwargs.get("pattern", "")
|
||||
path = kwargs.get("path", _SANDBOX_HOME)
|
||||
include = kwargs.get("include", "")
|
||||
|
||||
sandbox = await _get_sandbox(session)
|
||||
if sandbox is None:
|
||||
return _sandbox_unavailable(session)
|
||||
|
||||
resolved = _resolve_path(path)
|
||||
if resolved is None:
|
||||
return _path_error(path, session)
|
||||
|
||||
include_flag = f" --include={shlex.quote(include)}" if include else ""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"grep -rn{include_flag} {shlex.quote(pattern)} {resolved} 2>/dev/null",
|
||||
timeout=15,
|
||||
)
|
||||
return BashExecResponse(
|
||||
message="Grep results",
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to grep: {e}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
|
||||
class SaveToWorkspaceTool(BaseTool):
    """Copy a file from e2b sandbox to the persistent GCS workspace."""

    @property
    def name(self) -> str:
        return "save_to_workspace"

    @property
    def description(self) -> str:
        return (
            "Save a file from the sandbox to the persistent workspace "
            "(cloud storage). Files saved to workspace survive across sessions. "
            "Provide the sandbox file path and optional workspace path."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "sandbox_path": {
                    "type": "string",
                    "description": "Path of the file in the sandbox to save.",
                },
                "workspace_path": {
                    "type": "string",
                    "description": (
                        "Path in the workspace to save to "
                        "(defaults to the sandbox filename)."
                    ),
                },
            },
            "required": ["sandbox_path"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        sandbox_path = kwargs.get("sandbox_path", "")
        workspace_path = kwargs.get("workspace_path", "")

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session.session_id,
            )

        sandbox = await _get_sandbox(session)
        if sandbox is None:
            return _sandbox_unavailable(session)

        resolved = _resolve_path(sandbox_path)
        if resolved is None:
            return _path_error(sandbox_path, session)

        try:
            content_bytes = await sandbox.files.read(resolved, format="bytes")

            # Determine workspace path
            filename = resolved.rsplit("/", 1)[-1]
            wp = workspace_path or f"/{filename}"

            from backend.data.db_accessors import workspace_db
            from backend.util.workspace import WorkspaceManager

            workspace = await workspace_db().get_or_create_workspace(user_id)
            manager = WorkspaceManager(user_id, workspace.id, session.session_id)
            file_record = await manager.write_file(
                content=content_bytes,
                filename=filename,
                path=wp,
                overwrite=True,
            )

            return BashExecResponse(
                message=f"Saved {resolved} to workspace at {file_record.path}",
                stdout=(
                    f"Saved to workspace: {file_record.path} "
                    f"({file_record.size_bytes} bytes)"
                ),
                stderr="",
                exit_code=0,
                timed_out=False,
                session_id=session.session_id,
            )
        except Exception as e:
            return ErrorResponse(
                message=f"Failed to save to workspace: {e}",
                error=str(e),
                session_id=session.session_id,
            )

class LoadFromWorkspaceTool(BaseTool):
    """Copy a file from the persistent GCS workspace into the e2b sandbox."""

    @property
    def name(self) -> str:
        return "load_from_workspace"

    @property
    def description(self) -> str:
        return (
            "Load a file from the persistent workspace (cloud storage) into "
            "the sandbox. Use this to bring workspace files into the sandbox "
            "for editing or processing."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "workspace_path": {
                    "type": "string",
                    "description": "Path of the file in the workspace to load.",
                },
                "sandbox_path": {
                    "type": "string",
                    "description": (
                        "Path in the sandbox to write to "
                        "(defaults to /home/user/<filename>)."
                    ),
                },
            },
            "required": ["workspace_path"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        workspace_path = kwargs.get("workspace_path", "")
        sandbox_path = kwargs.get("sandbox_path", "")

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session.session_id,
            )

        sandbox = await _get_sandbox(session)
        if sandbox is None:
            return _sandbox_unavailable(session)

        try:
            from backend.data.db_accessors import workspace_db
            from backend.util.workspace import WorkspaceManager

            workspace = await workspace_db().get_or_create_workspace(user_id)
            manager = WorkspaceManager(user_id, workspace.id, session.session_id)
            file_info = await manager.get_file_info_by_path(workspace_path)
            if file_info is None:
                return ErrorResponse(
                    message=f"File not found in workspace: {workspace_path}",
                    session_id=session.session_id,
                )
            content = await manager.read_file_by_id(file_info.id)

            # Determine sandbox path
            filename = workspace_path.rsplit("/", 1)[-1]
            target = sandbox_path or f"{_SANDBOX_HOME}/{filename}"
            resolved = _resolve_path(target)
            if resolved is None:
                return _path_error(target, session)

            # Ensure parent directory exists
            parent = posixpath.dirname(resolved)
            if parent and parent != _SANDBOX_HOME:
                await sandbox.commands.run(f"mkdir -p {parent}", timeout=5)
            await sandbox.files.write(resolved, content)

            return BashExecResponse(
                message=f"Loaded {workspace_path} into sandbox at {resolved}",
                stdout=(f"Loaded from workspace: {resolved} ({len(content)} bytes)"),
                stderr="",
                exit_code=0,
                timed_out=False,
                session_id=session.session_id,
            )
        except Exception as e:
            return ErrorResponse(
                message=f"Failed to load from workspace: {e}",
                error=str(e),
                session_id=session.session_id,
            )

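
Together, `save_to_workspace` and `load_from_workspace` give the copilot a round trip between the throwaway sandbox filesystem and persistent cloud storage. A rough sketch of that flow from the tool-call side (the `call_tool` dispatcher below is hypothetical; the argument keys match the parameter schemas above):

```python
# Hypothetical dispatcher: in the real system the model emits these tool calls
# and the executor routes them to the tool classes defined above.
async def persist_and_restore(call_tool) -> None:
    # 1. Push a sandbox artifact into the workspace so it outlives the sandbox.
    await call_tool("save_to_workspace", {"sandbox_path": "/home/user/report.csv"})

    # 2. Later, possibly in a fresh sandbox, pull it back for further processing.
    await call_tool(
        "load_from_workspace",
        {"workspace_path": "/report.csv", "sandbox_path": "/home/user/report.csv"},
    )
```
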
# ------------------------------------------------------------------
# Module-level helpers (placed after functions that call them)
# ------------------------------------------------------------------


def _resolve_path(path: str) -> str | None:
    """Resolve a path to an absolute path within /home/user/.

    Returns None if the path escapes the sandbox home.
    """
    if not path:
        return None

    # Handle relative paths
    if not path.startswith("/"):
        path = f"{_SANDBOX_HOME}/{path}"

    # Normalize to prevent traversal
    resolved = posixpath.normpath(path)

    if not resolved.startswith(_SANDBOX_HOME):
        return None

    return resolved


async def _get_sandbox(session: ChatSession) -> Any | None:
    """Get the sandbox for the current session from the execution context."""
    try:
        from backend.copilot.sdk.tool_adapter import get_sandbox_manager

        manager = get_sandbox_manager()
        if manager is None:
            return None
        user_id, _ = _get_user_from_context()
        return await manager.get_or_create(session.session_id, user_id or "anonymous")
    except Exception as e:
        logger.error(f"[E2B] Failed to get sandbox: {e}")
        return None


def _get_user_from_context() -> tuple[str | None, Any]:
    """Get user_id from execution context."""
    from backend.copilot.sdk.tool_adapter import get_execution_context

    return get_execution_context()


def _sandbox_unavailable(session: ChatSession) -> ErrorResponse:
    """Return an error response for unavailable sandbox."""
    return ErrorResponse(
        message="E2B sandbox is not available. Try again or contact support.",
        error="sandbox_unavailable",
        session_id=session.session_id,
    )


def _path_error(path: str, session: ChatSession) -> ErrorResponse:
    """Return an error response for invalid paths."""
    return ErrorResponse(
        message=f"Invalid path: {path}. Paths must be within /home/user/.",
        error="invalid_path",
        session_id=session.session_id,
    )
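
To make the containment rule concrete, here is `_resolve_path` restated as a self-contained snippet with the behaviour it guarantees (the assertions follow directly from `posixpath.normpath` on the inputs):

```python
import posixpath

_SANDBOX_HOME = "/home/user"


def _resolve_path(path: str) -> str | None:
    # Relative paths are anchored under the sandbox home, then normalized;
    # anything that normalizes outside the home prefix is rejected.
    if not path:
        return None
    if not path.startswith("/"):
        path = f"{_SANDBOX_HOME}/{path}"
    resolved = posixpath.normpath(path)
    return resolved if resolved.startswith(_SANDBOX_HOME) else None


assert _resolve_path("notes/todo.md") == "/home/user/notes/todo.md"  # relative -> home
assert _resolve_path("/home/user/a/../b.txt") == "/home/user/b.txt"  # normalized
assert _resolve_path("../../etc/passwd") is None                     # traversal rejected
```
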
215  autogpt_platform/backend/backend/copilot/tools/e2b_sandbox.py  (Normal file)
@@ -0,0 +1,215 @@
"""E2B sandbox manager for CoPilot sessions.
|
||||
|
||||
Manages e2b sandbox lifecycle: create, reuse via Redis, dispose with GCS sync.
|
||||
One sandbox per session, cached in-memory on the worker thread and stored in
|
||||
Redis for cross-pod reconnection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_REDIS_KEY_PREFIX = "copilot:sandbox:"
|
||||
_SANDBOX_HOME = "/home/user"
|
||||
|
||||
|
||||
class CoPilotSandboxManager:
|
||||
"""Manages e2b sandbox lifecycle for CoPilot sessions.
|
||||
|
||||
Each session gets a single sandbox. The sandbox_id is stored in Redis
|
||||
so another pod can reconnect to it if the original pod dies.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sandboxes: dict[str, Any] = {} # session_id -> AsyncSandbox
|
||||
self._last_activity: dict[str, float] = {} # session_id -> timestamp
|
||||
self._cleanup_task: asyncio.Task[None] | None = None
|
||||
config = Config()
|
||||
self._timeout: int = config.copilot_sandbox_timeout
|
||||
self._template: str = config.copilot_sandbox_template
|
||||
self._api_key: str = config.e2b_api_key
|
||||
|
||||
async def get_or_create(self, session_id: str, user_id: str) -> Any:
|
||||
"""Get existing sandbox or create a new one for this session.
|
||||
|
||||
Args:
|
||||
session_id: CoPilot chat session ID.
|
||||
user_id: User ID for workspace sync.
|
||||
|
||||
Returns:
|
||||
An e2b AsyncSandbox instance.
|
||||
"""
|
||||
self._last_activity[session_id] = time.monotonic()
|
||||
|
||||
# 1. Check in-memory cache
|
||||
if session_id in self._sandboxes:
|
||||
sandbox = self._sandboxes[session_id]
|
||||
if await _is_sandbox_alive(sandbox):
|
||||
return sandbox
|
||||
# Sandbox died — clean up stale reference
|
||||
del self._sandboxes[session_id]
|
||||
|
||||
# 2. Check Redis for sandbox_id (cross-pod reconnection)
|
||||
sandbox = await self._try_reconnect_from_redis(session_id)
|
||||
if sandbox is not None:
|
||||
self._sandboxes[session_id] = sandbox
|
||||
return sandbox
|
||||
|
||||
# 3. Create new sandbox
|
||||
sandbox = await self._create_sandbox(session_id, user_id)
|
||||
self._sandboxes[session_id] = sandbox
|
||||
await _store_sandbox_id_in_redis(session_id, sandbox.sandbox_id)
|
||||
|
||||
# 4. Start cleanup task if not running
|
||||
self._ensure_cleanup_task()
|
||||
|
||||
return sandbox
|
||||
|
||||
async def dispose(self, session_id: str) -> None:
|
||||
"""Persist workspace files to GCS, then kill sandbox.
|
||||
|
||||
Args:
|
||||
session_id: CoPilot chat session ID.
|
||||
"""
|
||||
sandbox = self._sandboxes.pop(session_id, None)
|
||||
self._last_activity.pop(session_id, None)
|
||||
|
||||
if sandbox is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await sandbox.kill()
|
||||
except Exception as e:
|
||||
logger.warning(f"[E2B] Failed to kill sandbox for {session_id}: {e}")
|
||||
|
||||
await _remove_sandbox_id_from_redis(session_id)
|
||||
logger.info(f"[E2B] Disposed sandbox for session {session_id}")
|
||||
|
||||
async def dispose_all(self) -> None:
|
||||
"""Dispose all sandboxes (called on processor shutdown)."""
|
||||
session_ids = list(self._sandboxes.keys())
|
||||
for sid in session_ids:
|
||||
await self.dispose(sid)
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _create_sandbox(self, session_id: str, user_id: str) -> Any:
|
||||
"""Create a new e2b sandbox."""
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
kwargs: dict[str, Any] = {"api_key": self._api_key}
|
||||
if self._template:
|
||||
kwargs["template"] = self._template
|
||||
if self._timeout:
|
||||
kwargs["timeout"] = self._timeout
|
||||
|
||||
sandbox = await AsyncSandbox.create(**kwargs)
|
||||
logger.info(
|
||||
f"[E2B] Created sandbox {sandbox.sandbox_id} for session={session_id}, "
|
||||
f"user={user_id}"
|
||||
)
|
||||
return sandbox
|
||||
|
||||
async def _try_reconnect_from_redis(self, session_id: str) -> Any | None:
|
||||
"""Attempt to reconnect to a sandbox stored in Redis."""
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
sandbox_id = await _load_sandbox_id_from_redis(session_id)
|
||||
if not sandbox_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(
|
||||
sandbox_id=sandbox_id, api_key=self._api_key
|
||||
)
|
||||
logger.info(
|
||||
f"[E2B] Reconnected to sandbox {sandbox_id} for session={session_id}"
|
||||
)
|
||||
return sandbox
|
||||
except Exception as e:
|
||||
logger.warning(f"[E2B] Failed to reconnect to sandbox {sandbox_id}: {e}")
|
||||
await _remove_sandbox_id_from_redis(session_id)
|
||||
return None
|
||||
|
||||
def _ensure_cleanup_task(self) -> None:
|
||||
"""Start the idle cleanup background task if not already running."""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
self._cleanup_task = asyncio.ensure_future(self._idle_cleanup_loop())
|
||||
|
||||
async def _idle_cleanup_loop(self) -> None:
|
||||
"""Periodically check for idle sandboxes and dispose them."""
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
if not self._sandboxes:
|
||||
continue
|
||||
now = time.monotonic()
|
||||
to_dispose: list[str] = []
|
||||
for sid, last in list(self._last_activity.items()):
|
||||
if now - last > self._timeout:
|
||||
to_dispose.append(sid)
|
||||
for sid in to_dispose:
|
||||
logger.info(f"[E2B] Disposing idle sandbox for session {sid}")
|
||||
await self.dispose(sid)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Module-level helpers (placed after classes that call them)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _is_sandbox_alive(sandbox: Any) -> bool:
|
||||
"""Check if an e2b sandbox is still running."""
|
||||
try:
|
||||
result = await sandbox.commands.run("echo ok", timeout=5)
|
||||
return result.exit_code == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def _store_sandbox_id_in_redis(session_id: str, sandbox_id: str) -> None:
|
||||
"""Store sandbox_id in Redis keyed by session_id."""
|
||||
try:
|
||||
from backend.data import redis as redis_client
|
||||
|
||||
redis = redis_client.get_redis()
|
||||
key = f"{_REDIS_KEY_PREFIX}{session_id}"
|
||||
config = Config()
|
||||
ttl = max(config.copilot_sandbox_timeout * 2, 3600) # At least 1h, 2x timeout
|
||||
await redis.set(key, sandbox_id, ex=ttl)
|
||||
except Exception as e:
|
||||
logger.warning(f"[E2B] Failed to store sandbox_id in Redis: {e}")
|
||||
|
||||
|
||||
async def _load_sandbox_id_from_redis(session_id: str) -> str | None:
|
||||
"""Load sandbox_id from Redis."""
|
||||
try:
|
||||
from backend.data import redis as redis_client
|
||||
|
||||
redis = redis_client.get_redis()
|
||||
key = f"{_REDIS_KEY_PREFIX}{session_id}"
|
||||
value = await redis.get(key)
|
||||
return value.decode() if isinstance(value, bytes) else value
|
||||
except Exception as e:
|
||||
logger.warning(f"[E2B] Failed to load sandbox_id from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _remove_sandbox_id_from_redis(session_id: str) -> None:
|
||||
"""Remove sandbox_id from Redis."""
|
||||
try:
|
||||
from backend.data import redis as redis_client
|
||||
|
||||
redis = redis_client.get_redis()
|
||||
key = f"{_REDIS_KEY_PREFIX}{session_id}"
|
||||
await redis.delete(key)
|
||||
except Exception as e:
|
||||
logger.warning(f"[E2B] Failed to remove sandbox_id from Redis: {e}")
|
||||
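
Putting the manager together: a sandbox is looked up first in the in-memory cache, then via the `copilot:sandbox:<session_id>` Redis key (stored with a TTL of `max(2 * timeout, 3600)` seconds), and only then created fresh; idle sandboxes are reaped by the background loop. A condensed, illustrative driver (not the actual executor wiring, which goes through the tool adapter):

```python
# Illustrative only: the real executor obtains the manager via get_sandbox_manager().
async def run_in_session_sandbox(
    manager: CoPilotSandboxManager, session_id: str, user_id: str
) -> str:
    sandbox = await manager.get_or_create(session_id, user_id)  # cache -> Redis -> create
    result = await sandbox.commands.run("python3 --version", timeout=15)
    return result.stdout


# On worker shutdown, or once a session sits idle past copilot_sandbox_timeout,
# dispose()/dispose_all() kill the sandbox and clear its Redis key.
```
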
@@ -2,7 +2,7 @@

from datetime import datetime
from enum import Enum
from typing import Any, Literal
from typing import Any

from pydantic import BaseModel, Field

@@ -50,8 +50,6 @@ class ResponseType(str, Enum):
    # Feature request types
    FEATURE_REQUEST_SEARCH = "feature_request_search"
    FEATURE_REQUEST_CREATED = "feature_request_created"
    # Goal refinement
    SUGGESTED_GOAL = "suggested_goal"


# Base response model
@@ -298,22 +296,6 @@ class ClarificationNeededResponse(ToolResponseBase):
    questions: list[ClarifyingQuestion] = Field(default_factory=list)


class SuggestedGoalResponse(ToolResponseBase):
    """Response when the goal needs refinement with a suggested alternative."""

    type: ResponseType = ResponseType.SUGGESTED_GOAL
    suggested_goal: str = Field(description="The suggested alternative goal")
    reason: str = Field(
        default="", description="Why the original goal needs refinement"
    )
    original_goal: str = Field(
        default="", description="The user's original goal for context"
    )
    goal_type: Literal["vague", "unachievable"] = Field(
        default="vague", description="Type: 'vague' or 'unachievable'"
    )


# Documentation search models
class DocSearchResult(BaseModel):
    """A single documentation search result."""

@@ -39,6 +39,7 @@ class Flag(str, Enum):
    ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
    CHAT = "chat"
    COPILOT_SDK = "copilot-sdk"
    COPILOT_E2B = "copilot-e2b"


def is_configured() -> bool:

@@ -665,6 +665,18 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
    fal_api_key: str = Field(default="", description="FAL API key")
    exa_api_key: str = Field(default="", description="Exa API key")
    e2b_api_key: str = Field(default="", description="E2B API key")
    copilot_sandbox_timeout: int = Field(
        default=900,
        description="E2B sandbox idle timeout in seconds (default 15 min).",
    )
    copilot_sandbox_template: str = Field(
        default="",
        description="E2B sandbox template ID (empty = default template).",
    )
    copilot_use_e2b: bool = Field(
        default=False,
        description="Enable e2b sandbox for CoPilot (feature flag default).",
    )
    nvidia_api_key: str = Field(default="", description="Nvidia API key")
    mem0_api_key: str = Field(default="", description="Mem0 API key")
    elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
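
The three new settings sit alongside the existing secrets and are read through `Config()` in the sandbox manager above. A quick sketch of how the values are consumed (field names as defined in the diff; how they are supplied depends on the deployment's usual settings mechanism):

```python
from backend.util.settings import Config

config = Config()
# The idle timeout also drives the Redis TTL for stored sandbox IDs (2x, minimum 1 hour).
idle_timeout_s = config.copilot_sandbox_timeout  # default 900 (15 min)
template_id = config.copilot_sandbox_template    # "" means the e2b default template
e2b_enabled = config.copilot_use_e2b             # default False; cf. the new "copilot-e2b" flag
```
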
@@ -30,16 +30,6 @@ pnpm format
pnpm types
```

### Pre-completion Checks (MANDATORY)

After making **any** code changes in the frontend, you MUST run the following commands **in order** before reporting work as done, creating commits, or opening PRs:

1. `pnpm format` — auto-fix formatting issues
2. `pnpm lint` — check for lint errors; fix any that appear
3. `pnpm types` — check for type errors; fix any that appear

Do NOT skip these steps. If any command reports errors, fix them and re-run until clean. Only then may you consider the task complete. If typing keeps failing, stop and ask the user.

### Code Style

- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`

@@ -84,4 +74,3 @@ See @CONTRIBUTING.md for complete patterns. Quick reference:
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Do not type hook returns, let Typescript infer as much as possible
- Never type with `any` unless a variable/attribute can ACTUALLY be of any type
- avoid index and barrel files
@@ -26,7 +26,6 @@ import {
} from "./components/ClarificationQuestionsCard";
import sparklesImg from "./components/MiniGame/assets/sparkles.png";
import { MiniGame } from "./components/MiniGame/MiniGame";
import { SuggestedGoalCard } from "./components/SuggestedGoalCard";
import {
  AccordionIcon,
  formatMaybeJson,
@@ -39,7 +38,6 @@ import {
  isOperationInProgressOutput,
  isOperationPendingOutput,
  isOperationStartedOutput,
  isSuggestedGoalOutput,
  ToolIcon,
  truncateText,
  type CreateAgentToolOutput,
@@ -79,13 +77,6 @@ function getAccordionMeta(output: CreateAgentToolOutput) {
      expanded: true,
    };
  }
  if (isSuggestedGoalOutput(output)) {
    return {
      icon,
      title: "Goal needs refinement",
      expanded: true,
    };
  }
  if (
    isOperationStartedOutput(output) ||
    isOperationPendingOutput(output) ||
@@ -134,13 +125,8 @@ export function CreateAgentTool({ part }: Props) {
      isAgentPreviewOutput(output) ||
      isAgentSavedOutput(output) ||
      isClarificationNeededOutput(output) ||
      isSuggestedGoalOutput(output) ||
      isErrorOutput(output));

  function handleUseSuggestedGoal(goal: string) {
    onSend(`Please create an agent with this goal: ${goal}`);
  }

  function handleClarificationAnswers(answers: Record<string, string>) {
    const questions =
      output && isClarificationNeededOutput(output)
@@ -259,16 +245,6 @@ export function CreateAgentTool({ part }: Props) {
        />
      )}

      {isSuggestedGoalOutput(output) && (
        <SuggestedGoalCard
          message={output.message}
          suggestedGoal={output.suggested_goal}
          reason={output.reason}
          goalType={output.goal_type ?? "vague"}
          onUseSuggestedGoal={handleUseSuggestedGoal}
        />
      )}

      {isErrorOutput(output) && (
        <ContentGrid>
          <ContentMessage>{output.message}</ContentMessage>
@@ -282,22 +258,6 @@ export function CreateAgentTool({ part }: Props) {
            {formatMaybeJson(output.details)}
          </ContentCodeBlock>
        )}
        <div className="flex gap-2">
          <Button
            variant="outline"
            size="small"
            onClick={() => onSend("Please try creating the agent again.")}
          >
            Try again
          </Button>
          <Button
            variant="outline"
            size="small"
            onClick={() => onSend("Can you help me simplify this goal?")}
          >
            Simplify goal
          </Button>
        </div>
      </ContentGrid>
    )}
  </ToolAccordion>

@@ -1,63 +0,0 @@
"use client";

import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { ArrowRightIcon, LightbulbIcon } from "@phosphor-icons/react";

interface Props {
  message: string;
  suggestedGoal: string;
  reason?: string;
  goalType: string;
  onUseSuggestedGoal: (goal: string) => void;
}

export function SuggestedGoalCard({
  message,
  suggestedGoal,
  reason,
  goalType,
  onUseSuggestedGoal,
}: Props) {
  return (
    <div className="rounded-xl border border-amber-200 bg-amber-50/50 p-4">
      <div className="flex items-start gap-3">
        <LightbulbIcon
          size={20}
          weight="fill"
          className="mt-0.5 text-amber-600"
        />
        <div className="flex-1 space-y-3">
          <div>
            <Text variant="body-medium" className="font-medium text-slate-900">
              {goalType === "unachievable"
                ? "Goal cannot be accomplished"
                : "Goal needs more detail"}
            </Text>
            <Text variant="small" className="text-slate-600">
              {reason || message}
            </Text>
          </div>

          <div className="rounded-lg border border-amber-300 bg-white p-3">
            <Text variant="small" className="mb-1 font-semibold text-amber-800">
              Suggested alternative:
            </Text>
            <Text variant="body-medium" className="text-slate-900">
              {suggestedGoal}
            </Text>
          </div>

          <Button
            onClick={() => onUseSuggestedGoal(suggestedGoal)}
            variant="primary"
          >
            <span className="inline-flex items-center gap-1.5">
              Use this goal <ArrowRightIcon size={14} weight="bold" />
            </span>
          </Button>
        </div>
      </div>
    </div>
  );
}
@@ -6,7 +6,6 @@ import type { OperationInProgressResponse } from "@/app/api/__generated__/models
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import type { SuggestedGoalResponse } from "@/app/api/__generated__/models/suggestedGoalResponse";
import {
  PlusCircleIcon,
  PlusIcon,
@@ -22,7 +21,6 @@ export type CreateAgentToolOutput =
  | AgentPreviewResponse
  | AgentSavedResponse
  | ClarificationNeededResponse
  | SuggestedGoalResponse
  | ErrorResponse;

function parseOutput(output: unknown): CreateAgentToolOutput | null {
@@ -45,7 +43,6 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
    type === ResponseType.agent_preview ||
    type === ResponseType.agent_saved ||
    type === ResponseType.clarification_needed ||
    type === ResponseType.suggested_goal ||
    type === ResponseType.error
  ) {
    return output as CreateAgentToolOutput;
@@ -58,7 +55,6 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
    if ("agent_id" in output && "library_agent_id" in output)
      return output as AgentSavedResponse;
    if ("questions" in output) return output as ClarificationNeededResponse;
    if ("suggested_goal" in output) return output as SuggestedGoalResponse;
    if ("error" in output || "details" in output)
      return output as ErrorResponse;
  }
@@ -118,14 +114,6 @@ export function isClarificationNeededOutput(
  );
}

export function isSuggestedGoalOutput(
  output: CreateAgentToolOutput,
): output is SuggestedGoalResponse {
  return (
    output.type === ResponseType.suggested_goal || "suggested_goal" in output
  );
}

export function isErrorOutput(
  output: CreateAgentToolOutput,
): output is ErrorResponse {
@@ -151,7 +139,6 @@ export function getAnimationText(part: {
      if (isAgentSavedOutput(output)) return `Saved ${output.agent_name}`;
      if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
      if (isClarificationNeededOutput(output)) return "Needs clarification";
      if (isSuggestedGoalOutput(output)) return "Goal needs refinement";
      return "Error creating agent";
    }
    case "output-error":

@@ -1,6 +1,5 @@
import {
  getGetV2ListSessionsQueryKey,
  postV2CancelSessionTask,
  useDeleteV2DeleteSession,
  useGetV2ListSessions,
} from "@/app/api/__generated__/endpoints/chat/chat";
@@ -9,7 +8,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useChat } from "@ai-sdk/react";
import { useQueryClient } from "@tanstack/react-query";
import type { UIMessage } from "ai";
import { DefaultChatTransport } from "ai";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useChatSession } from "./useChatSession";
@@ -17,24 +15,6 @@ import { useLongRunningToolPolling } from "./hooks/useLongRunningToolPolling";

const STREAM_START_TIMEOUT_MS = 12_000;

/** Mark any in-progress tool parts as completed/errored so spinners stop. */
function resolveInProgressTools(
  messages: UIMessage[],
  outcome: "completed" | "cancelled",
): UIMessage[] {
  return messages.map((msg) => ({
    ...msg,
    parts: msg.parts.map((part) =>
      "state" in part &&
      (part.state === "input-streaming" || part.state === "input-available")
        ? outcome === "cancelled"
          ? { ...part, state: "output-error" as const, errorText: "Cancelled" }
          : { ...part, state: "output-available" as const, output: "" }
        : part,
    ),
  }));
}

export function useCopilotPage() {
  const { isUserLoading, isLoggedIn } = useSupabase();
  const [isDrawerOpen, setIsDrawerOpen] = useState(false);
@@ -115,7 +95,7 @@ export function useCopilotPage() {
  const {
    messages,
    sendMessage,
    stop: sdkStop,
    stop,
    status,
    error,
    setMessages,
@@ -128,36 +108,6 @@ export function useCopilotPage() {
    // call resumeStream() manually after hydration + active_stream detection.
  });

  // Wrap AI SDK's stop() to also cancel the backend executor task.
  // sdkStop() aborts the SSE fetch instantly (UI feedback), then we fire
  // the cancel API to actually stop the executor and wait for confirmation.
  async function stop() {
    sdkStop();
    setMessages((prev) => resolveInProgressTools(prev, "cancelled"));

    if (!sessionId) return;
    try {
      const res = await postV2CancelSessionTask(sessionId);
      if (
        res.status === 200 &&
        "reason" in res.data &&
        res.data.reason === "cancel_published_not_confirmed"
      ) {
        toast({
          title: "Stop may take a moment",
          description:
            "The cancel was sent but not yet confirmed. The task should stop shortly.",
        });
      }
    } catch {
      toast({
        title: "Could not stop the task",
        description: "The task may still be running in the background.",
        variant: "destructive",
      });
    }
  }

  // Abort the stream if the backend doesn't start sending data within 12s.
  const stopRef = useRef(stop);
  stopRef.current = stop;
@@ -202,18 +152,6 @@ export function useCopilotPage() {
    resumeStream();
  }, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);

  // When the stream finishes, resolve any tool parts still showing spinners.
  // This can happen if the backend didn't emit StreamToolOutputAvailable for
  // a tool call before sending StreamFinish (e.g. SDK built-in tools).
  const prevStatusRef = useRef(status);
  useEffect(() => {
    const prev = prevStatusRef.current;
    prevStatusRef.current = status;
    if (prev === "streaming" && status === "ready") {
      setMessages((msgs) => resolveInProgressTools(msgs, "completed"));
    }
  }, [status, setMessages]);

  // Poll session endpoint when a long-running tool (create_agent, edit_agent)
  // is in progress. When the backend completes, the session data will contain
  // the final tool output — this hook detects the change and updates messages.

@@ -1052,7 +1052,6 @@
            {
              "$ref": "#/components/schemas/ClarificationNeededResponse"
            },
            { "$ref": "#/components/schemas/SuggestedGoalResponse" },
            { "$ref": "#/components/schemas/BlockListResponse" },
            { "$ref": "#/components/schemas/BlockDetailsResponse" },
            { "$ref": "#/components/schemas/BlockOutputResponse" },
@@ -1263,44 +1262,6 @@
          }
        }
      }
    },
    "/api/chat/sessions/{session_id}/cancel": {
      "post": {
        "tags": ["v2", "chat", "chat"],
        "summary": "Cancel Session Task",
        "description": "Cancel the active streaming task for a session.\n\nPublishes a cancel event to the executor via RabbitMQ FANOUT, then\npolls Redis until the task status flips from ``running`` or a timeout\n(5 s) is reached. Returns only after the cancellation is confirmed.",
        "operationId": "postV2CancelSessionTask",
        "security": [{ "HTTPBearerJWT": [] }],
        "parameters": [
          {
            "name": "session_id",
            "in": "path",
            "required": true,
            "schema": { "type": "string", "title": "Session Id" }
          }
        ],
        "responses": {
          "200": {
            "description": "Successful Response",
            "content": {
              "application/json": {
                "schema": { "$ref": "#/components/schemas/CancelTaskResponse" }
              }
            }
          },
          "401": {
            "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
          },
          "422": {
            "description": "Validation Error",
            "content": {
              "application/json": {
                "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
              }
            }
          }
        }
      }
    },
    "/api/chat/sessions/{session_id}/stream": {
      "get": {
        "tags": ["v2", "chat", "chat"],
@@ -7575,23 +7536,6 @@
        "required": ["file"],
        "title": "Body_postV2Upload submission media"
      },
      "CancelTaskResponse": {
        "properties": {
          "cancelled": { "type": "boolean", "title": "Cancelled" },
          "task_id": {
            "anyOf": [{ "type": "string" }, { "type": "null" }],
            "title": "Task Id"
          },
          "reason": {
            "anyOf": [{ "type": "string" }, { "type": "null" }],
            "title": "Reason"
          }
        },
        "type": "object",
        "required": ["cancelled"],
        "title": "CancelTaskResponse",
        "description": "Response model for the cancel task endpoint."
      },
      "ChangelogEntry": {
        "properties": {
          "version": { "type": "string", "title": "Version" },
@@ -10852,8 +10796,7 @@
          "bash_exec",
          "operation_status",
          "feature_request_search",
          "feature_request_created",
          "suggested_goal"
          "feature_request_created"
        ],
        "title": "ResponseType",
        "description": "Types of tool responses."
@@ -11734,47 +11677,6 @@
        "enum": ["DRAFT", "PENDING", "APPROVED", "REJECTED"],
        "title": "SubmissionStatus"
      },
      "SuggestedGoalResponse": {
        "properties": {
          "type": {
            "$ref": "#/components/schemas/ResponseType",
            "default": "suggested_goal"
          },
          "message": { "type": "string", "title": "Message" },
          "session_id": {
            "anyOf": [{ "type": "string" }, { "type": "null" }],
            "title": "Session Id"
          },
          "suggested_goal": {
            "type": "string",
            "title": "Suggested Goal",
            "description": "The suggested alternative goal"
          },
          "reason": {
            "type": "string",
            "title": "Reason",
            "description": "Why the original goal needs refinement",
            "default": ""
          },
          "original_goal": {
            "type": "string",
            "title": "Original Goal",
            "description": "The user's original goal for context",
            "default": ""
          },
          "goal_type": {
            "type": "string",
            "enum": ["vague", "unachievable"],
            "title": "Goal Type",
            "description": "Type: 'vague' or 'unachievable'",
            "default": "vague"
          }
        },
        "type": "object",
        "required": ["message", "suggested_goal"],
        "title": "SuggestedGoalResponse",
        "description": "Response when the goal needs refinement with a suggested alternative."
      },
      "SuggestionsResponse": {
        "properties": {
          "otto_suggestions": {