mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-13 08:14:58 -05:00
Compare commits
96 Commits
pwuts/open
...
feat/copil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60c6c14211 | ||
|
|
a86878ae6f | ||
|
|
80e413b969 | ||
|
|
52c8a25531 | ||
|
|
d0f0c32e70 | ||
|
|
8dfd0a77a0 | ||
|
|
4bfd6c8870 | ||
|
|
1918828405 | ||
|
|
9c855b501b | ||
|
|
5c9d0577c0 | ||
|
|
a79bd88e7c | ||
|
|
28c1121a8f | ||
|
|
cb3839198c | ||
|
|
80804986b0 | ||
|
|
b915e67a9b | ||
|
|
32e9dda30d | ||
|
|
cb45e7957b | ||
|
|
f1d02fb8f3 | ||
|
|
47de6b6420 | ||
|
|
62cd2eea89 | ||
|
|
ae61ec692e | ||
|
|
9296bd8736 | ||
|
|
308113c03d | ||
|
|
51abf13254 | ||
|
|
54b03d3a29 | ||
|
|
239dff5ebd | ||
|
|
1dd53db21c | ||
|
|
06c16ee2fe | ||
|
|
8d2a649ee5 | ||
|
|
9589474709 | ||
|
|
749a78723a | ||
|
|
bec2e1ddee | ||
|
|
ec1ab06e0d | ||
|
|
f31cb49557 | ||
|
|
fd28c386f4 | ||
|
|
3bea584659 | ||
|
|
d7f7a2747f | ||
|
|
68849e197c | ||
|
|
211478bb29 | ||
|
|
0e88dd15b2 | ||
|
|
7f3c227f0a | ||
|
|
40b58807ab | ||
|
|
d0e2e6f013 | ||
|
|
efdc8d73cc | ||
|
|
a34810d8a2 | ||
|
|
038b7d5841 | ||
|
|
cac93b0cc9 | ||
|
|
2025aaf5f2 | ||
|
|
ae9bce3bae | ||
|
|
3107d889fc | ||
|
|
f174fb6303 | ||
|
|
920a4c5f15 | ||
|
|
e95fadbb86 | ||
|
|
b14b3803ad | ||
|
|
82c483d6c8 | ||
|
|
7cffa1895f | ||
|
|
9791bdd724 | ||
|
|
750a674c78 | ||
|
|
960c7980a3 | ||
|
|
e85d437bb2 | ||
|
|
44f9536bd6 | ||
|
|
1c1085a227 | ||
|
|
d7ef70469e | ||
|
|
1926127ddd | ||
|
|
8b509e56de | ||
|
|
acb2d0bd1b | ||
|
|
51aa369c80 | ||
|
|
6403ffe353 | ||
|
|
c40a98ba3c | ||
|
|
a31fc8b162 | ||
|
|
0f2d1a6553 | ||
|
|
87d817b83b | ||
|
|
acf932bf4f | ||
|
|
f562d9a277 | ||
|
|
3c92a96504 | ||
|
|
8b8e1df739 | ||
|
|
602a0a4fb1 | ||
|
|
8d7d531ae0 | ||
|
|
43153a12e0 | ||
|
|
587e11c60a | ||
|
|
57da545e02 | ||
|
|
626980bf27 | ||
|
|
e42b27af3c | ||
|
|
34face15d2 | ||
|
|
7d32c83f95 | ||
|
|
6e2a45b84e | ||
|
|
32f6532e9c | ||
|
|
0bbe8a184d | ||
|
|
7592deed63 | ||
|
|
b9c759ce4f | ||
|
|
5efb80d47b | ||
|
|
b49d8e2cba | ||
|
|
452544530d | ||
|
|
32ee7e6cf8 | ||
|
|
670663c406 | ||
|
|
0dbe4cf51e |
@@ -66,13 +66,19 @@ ENV POETRY_HOME=/opt/poetry \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||
# for the bash_exec MCP tool.
|
||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
imagemagick \
|
||||
jq \
|
||||
ripgrep \
|
||||
tree \
|
||||
bubblewrap \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
"""Common test fixtures for server tests.
|
||||
|
||||
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
|
||||
setup_test_user, and setup_admin_user are defined in the parent conftest.py
|
||||
(backend/conftest.py) and are available here automatically.
|
||||
"""
|
||||
"""Common test fixtures for server tests."""
|
||||
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
return snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_user(test_user_id):
|
||||
"""Provide mock JWT payload for regular user testing."""
|
||||
|
||||
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import find_agent_tool, run_agent_tool
|
||||
from backend.copilot.tools.models import ToolResponseBase
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -119,9 +119,8 @@ class ChatCompletionConsumer:
|
||||
"""Lazily initialize Prisma client on first use."""
|
||||
if self._prisma is None:
|
||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
prisma = Prisma(datasource={"url": database_url})
|
||||
await prisma.connect()
|
||||
self._prisma = prisma
|
||||
self._prisma = Prisma(datasource={"url": database_url})
|
||||
await self._prisma.connect()
|
||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||
return self._prisma
|
||||
|
||||
@@ -27,12 +27,11 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
)
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
@@ -93,6 +92,31 @@ class ChatConfig(BaseSettings):
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
)
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
description="Model for the Claude Agent SDK path. If None, derives from "
|
||||
"the `model` field by stripping the OpenRouter provider prefix.",
|
||||
)
|
||||
claude_agent_max_buffer_size: int = Field(
|
||||
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
|
||||
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
|
||||
"Increase if tool outputs exceed the limit.",
|
||||
)
|
||||
claude_agent_max_subtasks: int = Field(
|
||||
default=10,
|
||||
description="Max number of sub-agent Tasks the SDK can spawn per session.",
|
||||
)
|
||||
claude_agent_use_resume: bool = Field(
|
||||
default=True,
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
@@ -138,6 +162,17 @@ class ChatConfig(BaseSettings):
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
@@ -14,7 +14,7 @@ from prisma.types import (
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.db import transaction
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -147,7 +147,7 @@ async def add_chat_messages_batch(
|
||||
|
||||
created_messages = []
|
||||
|
||||
async with db.transaction() as tx:
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# ===================== Chat data models ===================== #
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -292,28 +322,39 @@ class ChatSession(BaseModel):
|
||||
return self._merge_consecutive_assistant_messages(messages)
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
|
||||
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
# ================ Chat cache + DB operations ================ #
|
||||
|
||||
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
|
||||
# connected directly.
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis (without persisting to the database)."""
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session without persisting to the database."""
|
||||
await _cache_session(session)
|
||||
|
||||
|
||||
async def invalidate_session_cache(session_id: str) -> None:
|
||||
"""Invalidate a chat session from Redis cache.
|
||||
|
||||
@@ -329,6 +370,77 @@ async def invalidate_session_cache(session_id: str) -> None:
|
||||
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.debug(
|
||||
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
|
||||
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
|
||||
f"roles={[m['role'] for m in messages_data]}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
@@ -360,7 +472,7 @@ async def get_chat_session(
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
@@ -376,53 +488,13 @@ async def get_chat_session(
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db().get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
@@ -443,7 +515,7 @@ async def upsert_chat_session(
|
||||
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
@@ -460,7 +532,7 @@ async def upsert_chat_session(
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
@@ -481,63 +553,38 @@ async def upsert_chat_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
db = chat_db()
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
Acquires the session lock, re-fetches the latest session state,
|
||||
appends the message, and saves — preventing message loss when
|
||||
concurrent requests modify the same session.
|
||||
"""
|
||||
lock = await _get_session_lock(session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
async with lock:
|
||||
session = await get_chat_session(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session_id
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to persist message to session {session_id}"
|
||||
) from e
|
||||
|
||||
# Update session metadata
|
||||
await db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
@@ -552,7 +599,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db().create_chat_session(
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -564,7 +611,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
@@ -582,9 +629,8 @@ async def get_user_sessions(
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
db = chat_db()
|
||||
prisma_sessions = await db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await db.get_user_session_count(user_id)
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
@@ -607,7 +653,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db().delete_chat_session(session_id, user_id)
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
@@ -642,46 +688,26 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await _cache_session(cached)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -10,22 +11,24 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.completion_handler import (
|
||||
process_operation_failure,
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_copilot_task
|
||||
from backend.copilot.model import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
)
|
||||
from backend.copilot.response_model import StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
|
||||
from .sdk import service as sdk_service
|
||||
from .tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
AgentPreviewResponse,
|
||||
@@ -48,7 +51,7 @@ from backend.copilot.tools.models import (
|
||||
SetupRequirementsResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from .tracking import track_user_message
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -240,6 +243,10 @@ async def get_session(
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
@@ -309,10 +316,9 @@ async def stream_chat_post(
|
||||
f"user={user_id}, message_len={len(request.message)}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
_session = await _validate_and_get_session(session_id, user_id) # noqa: F841
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -321,6 +327,25 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
session = await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
@@ -336,7 +361,7 @@ async def stream_chat_post(
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -345,20 +370,125 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Enqueue the task to RabbitMQ for processing by the CoPilot executor
|
||||
await enqueue_copilot_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
)
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
import time as time_module
|
||||
|
||||
gen_start_time = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
first_chunk_time, ttfc = None, None
|
||||
chunk_count = 0
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
logger.info(
|
||||
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
|
||||
* 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Choose service based on LaunchDarkly flag (falls back to config default)
|
||||
use_sdk = await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else chat_service.stream_chat_completion
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
# Pass message=None since we already added it to the session above
|
||||
async for chunk in stream_fn(
|
||||
session_id,
|
||||
None, # Message already in session
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session with message already added
|
||||
context=request.context,
|
||||
):
|
||||
# Skip duplicate StreamStart — we already published one above
|
||||
if isinstance(chunk, StreamStart):
|
||||
continue
|
||||
chunk_count += 1
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time_module.perf_counter()
|
||||
ttfc = first_chunk_time - gen_start_time
|
||||
logger.info(
|
||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"time_to_first_chunk_ms": ttfc * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
gen_end_time = time_module.perf_counter()
|
||||
total_time = (gen_end_time - gen_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
|
||||
f"task={task_id}, session={session_id}, "
|
||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"time_to_first_chunk_ms": (
|
||||
ttfc * 1000 if ttfc is not None else None
|
||||
),
|
||||
"n_chunks": chunk_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
elapsed = time_module.perf_counter() - gen_start_time
|
||||
logger.error(
|
||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
# Publish a StreamError so the frontend can display an error message
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass # Best-effort; mark_task_completed will publish StreamFinish
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
|
||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
@@ -453,8 +583,14 @@ async def stream_chat_post(
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
# Surface error to frontend so it doesn't appear stuck
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
@@ -698,8 +834,6 @@ async def stream_task(
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Claude Agent SDK integration for CoPilot.
|
||||
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
"""
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This module provides the adapter layer that converts streaming messages from
|
||||
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
||||
the frontend expects.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.api.features.chat.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.api.features.chat.sdk.tool_adapter import (
|
||||
MCP_TOOL_PREFIX,
|
||||
pop_pending_tool_output,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKResponseAdapter:
|
||||
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This class maintains state during a streaming session to properly track
|
||||
text blocks, tool calls, and message lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, message_id: str | None = None):
|
||||
self.message_id = message_id or str(uuid.uuid4())
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.task_id: str | None = None
|
||||
self.step_open = False
|
||||
|
||||
def set_task_id(self, task_id: str) -> None:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
|
||||
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format."""
|
||||
responses: list[StreamBaseResponse] = []
|
||||
|
||||
if isinstance(sdk_message, SystemMessage):
|
||||
if sdk_message.subtype == "init":
|
||||
responses.append(
|
||||
StreamStart(messageId=self.message_id, taskId=self.task_id)
|
||||
)
|
||||
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
elif isinstance(sdk_message, AssistantMessage):
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
if not self.step_open:
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
for block in sdk_message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
if block.text:
|
||||
self._ensure_text_started(responses)
|
||||
responses.append(
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
# Strip MCP prefix so frontend sees "find_block"
|
||||
# instead of "mcp__copilot__find_block".
|
||||
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
responses.append(
|
||||
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
|
||||
)
|
||||
responses.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=tool_name,
|
||||
input=block.input,
|
||||
)
|
||||
)
|
||||
self.current_tool_calls[block.id] = {"name": tool_name}
|
||||
|
||||
elif isinstance(sdk_message, UserMessage):
|
||||
# UserMessage carries tool results back from tool execution.
|
||||
content = sdk_message.content
|
||||
blocks = content if isinstance(content, list) else []
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Prefer the stashed full output over the SDK's
|
||||
# (potentially truncated) ToolResultBlock content.
|
||||
# The SDK truncates large results, writing them to disk,
|
||||
# which breaks frontend widget parsing.
|
||||
output = pop_pending_tool_output(tool_name) or (
|
||||
_extract_tool_output(block.content)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=block.tool_use_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=not (block.is_error or False),
|
||||
)
|
||||
)
|
||||
|
||||
# Close the current step after tool results — the next
|
||||
# AssistantMessage will open a new step for the continuation.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
self._end_text_if_open(responses)
|
||||
# Close the step before finishing.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Start (or restart) a text block if needed."""
|
||||
if not self.has_started_text or self.has_ended_text:
|
||||
if self.has_ended_text:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_ended_text = False
|
||||
responses.append(StreamTextStart(id=self.text_block_id))
|
||||
self.has_started_text = True
|
||||
|
||||
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""End the current text block if one is open."""
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
|
||||
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
||||
"""Extract a string output from a ToolResultBlock's content field."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||
if parts:
|
||||
return "".join(parts)
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except (TypeError, ValueError):
|
||||
return str(content)
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except (TypeError, ValueError):
|
||||
return str(content)
|
||||
@@ -0,0 +1,366 @@
|
||||
"""Unit tests for the SDK response adapter."""
|
||||
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.api.features.chat.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .tool_adapter import MCP_TOOL_PREFIX
|
||||
|
||||
|
||||
def _adapter() -> SDKResponseAdapter:
|
||||
a = SDKResponseAdapter(message_id="msg-1")
|
||||
a.set_task_id("task-1")
|
||||
return a
|
||||
|
||||
|
||||
# -- SystemMessage -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_system_init_emits_start_and_step():
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamStart)
|
||||
assert results[0].messageId == "msg-1"
|
||||
assert results[0].taskId == "task-1"
|
||||
assert isinstance(results[1], StreamStartStep)
|
||||
|
||||
|
||||
def test_system_non_init_emits_nothing():
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
|
||||
assert results == []
|
||||
|
||||
|
||||
# -- AssistantMessage with TextBlock -----------------------------------------
|
||||
|
||||
|
||||
def test_text_block_emits_step_start_and_delta():
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||
results = adapter.convert_message(msg)
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamTextStart)
|
||||
assert isinstance(results[2], StreamTextDelta)
|
||||
assert results[2].delta == "hello"
|
||||
|
||||
|
||||
def test_empty_text_block_emits_only_step():
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
|
||||
results = adapter.convert_message(msg)
|
||||
# Empty text skipped, but step still opens
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
|
||||
|
||||
def test_multiple_text_deltas_reuse_block_id():
|
||||
adapter = _adapter()
|
||||
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
|
||||
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
|
||||
r1 = adapter.convert_message(msg1)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
# First gets step+start+delta, second only delta (block & step already started)
|
||||
assert len(r1) == 3
|
||||
assert isinstance(r1[0], StreamStartStep)
|
||||
assert isinstance(r1[1], StreamTextStart)
|
||||
assert len(r2) == 1
|
||||
assert isinstance(r2[0], StreamTextDelta)
|
||||
assert r1[1].id == r2[0].id # same block ID
|
||||
|
||||
|
||||
# -- AssistantMessage with ToolUseBlock --------------------------------------
|
||||
|
||||
|
||||
def test_tool_use_emits_input_start_and_available():
|
||||
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(
|
||||
id="tool-1",
|
||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||
input={"q": "x"},
|
||||
)
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamToolInputStart)
|
||||
assert results[1].toolCallId == "tool-1"
|
||||
assert results[1].toolName == "find_agent" # prefix stripped
|
||||
assert isinstance(results[2], StreamToolInputAvailable)
|
||||
assert results[2].toolName == "find_agent" # prefix stripped
|
||||
assert results[2].input == {"q": "x"}
|
||||
|
||||
|
||||
def test_text_then_tool_ends_text_block():
|
||||
adapter = _adapter()
|
||||
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
|
||||
tool_msg = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(text_msg) # opens step + text
|
||||
results = adapter.convert_message(tool_msg)
|
||||
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamTextEnd)
|
||||
assert isinstance(results[1], StreamToolInputStart)
|
||||
|
||||
|
||||
# -- UserMessage with ToolResultBlock ----------------------------------------
|
||||
|
||||
|
||||
def test_tool_result_emits_output_and_finish_step():
|
||||
adapter = _adapter()
|
||||
# First register the tool call (opens step) — SDK sends prefixed name
|
||||
tool_msg = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(tool_msg)
|
||||
|
||||
# Now send tool result
|
||||
result_msg = UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].toolCallId == "t1"
|
||||
assert results[0].toolName == "find_agent" # prefix stripped
|
||||
assert results[0].output == "found 3 agents"
|
||||
assert results[0].success is True
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_tool_result_error():
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
result_msg = UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].success is False
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_tool_result_list_content():
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
result_msg = UserMessage(
|
||||
content=[
|
||||
ToolResultBlock(
|
||||
tool_use_id="t1",
|
||||
content=[
|
||||
{"type": "text", "text": "line1"},
|
||||
{"type": "text", "text": "line2"},
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].output == "line1line2"
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_string_user_message_ignored():
|
||||
"""A plain string UserMessage (not tool results) produces no output."""
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(UserMessage(content="hello"))
|
||||
assert results == []
|
||||
|
||||
|
||||
# -- ResultMessage -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_result_success_emits_finish_step_and_finish():
|
||||
adapter = _adapter()
|
||||
# Start some text first (opens step)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="done")], model="test")
|
||||
)
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# TextEnd + FinishStep + StreamFinish
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamTextEnd)
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
assert isinstance(results[2], StreamFinish)
|
||||
|
||||
|
||||
def test_result_error_emits_error_and_finish():
|
||||
adapter = _adapter()
|
||||
msg = ResultMessage(
|
||||
subtype="error",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=True,
|
||||
num_turns=0,
|
||||
session_id="s1",
|
||||
result="API rate limited",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# No step was open, so no FinishStep — just Error + Finish
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamError)
|
||||
assert "API rate limited" in results[0].errorText
|
||||
assert isinstance(results[1], StreamFinish)
|
||||
|
||||
|
||||
# -- Text after tools (new block ID) ----------------------------------------
|
||||
|
||||
|
||||
def test_text_after_tool_gets_new_block_id():
|
||||
adapter = _adapter()
|
||||
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="before")], model="test")
|
||||
)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
# Send tool result (closes step)
|
||||
adapter.convert_message(
|
||||
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
|
||||
)
|
||||
results = adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="after")], model="test")
|
||||
)
|
||||
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamTextStart)
|
||||
assert isinstance(results[2], StreamTextDelta)
|
||||
assert results[2].delta == "after"
|
||||
|
||||
|
||||
# -- Full conversation flow --------------------------------------------------
|
||||
|
||||
|
||||
def test_full_conversation_flow():
|
||||
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Assistant text
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
|
||||
)
|
||||
)
|
||||
# 3. Tool use
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(
|
||||
id="t1",
|
||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||
input={"query": "email"},
|
||||
)
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 4. Tool result
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
|
||||
)
|
||||
)
|
||||
)
|
||||
# 5. More text
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
|
||||
)
|
||||
)
|
||||
# 6. Result
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=500,
|
||||
duration_api_ms=400,
|
||||
is_error=False,
|
||||
num_turns=2,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep", # step 1: text + tool call
|
||||
"StreamTextStart",
|
||||
"StreamTextDelta", # "Let me search"
|
||||
"StreamTextEnd", # closed before tool
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
"StreamToolOutputAvailable", # tool result
|
||||
"StreamFinishStep", # step 1 closed after tool result
|
||||
"StreamStartStep", # step 2: continuation text
|
||||
"StreamTextStart", # new block after tool
|
||||
"StreamTextDelta", # "I found 2"
|
||||
"StreamTextEnd", # closed by result
|
||||
"StreamFinishStep", # step 2 closed
|
||||
"StreamFinish",
|
||||
]
|
||||
@@ -0,0 +1,335 @@
|
||||
"""Security hooks for Claude Agent SDK integration.
|
||||
|
||||
This module provides security hooks that validate tool calls before execution,
|
||||
ensuring multi-user isolation and preventing unauthorized operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that are blocked entirely (CLI/system access).
|
||||
# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked
|
||||
# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead,
|
||||
# which has kernel-level network isolation (unshare --net).
|
||||
BLOCKED_TOOLS = {
|
||||
"Bash",
|
||||
"bash",
|
||||
"shell",
|
||||
"exec",
|
||||
"terminal",
|
||||
"command",
|
||||
}
|
||||
|
||||
# Tools allowed only when their path argument stays within the SDK workspace.
|
||||
# The SDK uses these to handle oversized tool results (writes to tool-results/
|
||||
# files, then reads them back) and for workspace file operations.
|
||||
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
|
||||
|
||||
# Dangerous patterns in tool inputs
|
||||
DANGEROUS_PATTERNS = [
|
||||
r"sudo",
|
||||
r"rm\s+-rf",
|
||||
r"dd\s+if=",
|
||||
r"/etc/passwd",
|
||||
r"/etc/shadow",
|
||||
r"chmod\s+777",
|
||||
r"curl\s+.*\|.*sh",
|
||||
r"wget\s+.*\|.*sh",
|
||||
r"eval\s*\(",
|
||||
r"exec\s*\(",
|
||||
r"__import__",
|
||||
r"os\.system",
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
|
||||
def _deny(reason: str) -> dict[str, Any]:
|
||||
"""Return a hook denial response."""
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": reason,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _validate_workspace_path(
|
||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
||||
|
||||
Allowed directories:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
||||
return {}
|
||||
|
||||
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
||||
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
||||
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||
return {}
|
||||
|
||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
||||
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
tool_results_seg = os.sep + "tool-results" + os.sep
|
||||
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
||||
return {}
|
||||
|
||||
logger.warning(
|
||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
||||
)
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
f"directory.{workspace_hint} "
|
||||
"This is enforced by the platform and cannot be bypassed."
|
||||
)
|
||||
|
||||
|
||||
def _validate_tool_access(
|
||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a tool call is allowed.
|
||||
|
||||
Returns:
|
||||
Empty dict to allow, or dict with hookSpecificOutput to deny
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
"Use the CoPilot-specific MCP tools instead."
|
||||
)
|
||||
|
||||
# Workspace-scoped tools: allowed only within the SDK workspace directory
|
||||
if tool_name in WORKSPACE_SCOPED_TOOLS:
|
||||
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
|
||||
|
||||
# Check for dangerous patterns in tool input
|
||||
# Use json.dumps for predictable format (str() produces Python repr)
|
||||
input_str = json.dumps(tool_input) if tool_input else ""
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
)
|
||||
return _deny(
|
||||
"[SECURITY] Input contains a blocked pattern. "
|
||||
"This is enforced by the platform and cannot be bypassed."
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _validate_user_isolation(
|
||||
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that tool calls respect user isolation."""
|
||||
# For workspace file tools, ensure path doesn't escape
|
||||
if "workspace" in tool_name.lower():
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path:
|
||||
# Check for path traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
logger.warning(
|
||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def create_security_hooks(
|
||||
user_id: str | None,
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_stop: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
Includes security validation and observability hooks:
|
||||
- PreToolUse: Security validation before tool execution
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
|
||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
||||
the SDK finishes processing — used to read the JSONL transcript
|
||||
before the CLI process exits.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
# Per-session counter for Task sub-agent spawns
|
||||
task_spawn_count = 0
|
||||
|
||||
async def pre_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Combined pre-tool-use validation hook."""
|
||||
nonlocal task_spawn_count
|
||||
_ = context # unused but required by signature
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Rate-limit Task (sub-agent) spawns per session
|
||||
if tool_name == "Task":
|
||||
task_spawn_count += 1
|
||||
if task_spawn_count > max_subtasks:
|
||||
logger.warning(
|
||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||
)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
f"Maximum {max_subtasks} sub-tasks per session. "
|
||||
"Please continue in the main conversation."
|
||||
),
|
||||
)
|
||||
|
||||
# Strip MCP prefix for consistent validation
|
||||
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
# Only block non-CoPilot tools; our MCP-registered tools
|
||||
# (including Read for oversized results) are already sandboxed.
|
||||
if not is_copilot_tool:
|
||||
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
# Validate user isolation
|
||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log successful tool executions for observability."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_failure_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log failed tool executions for debugging."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def pre_compact_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log when SDK triggers context compaction.
|
||||
|
||||
The SDK automatically compacts conversation history when it grows too large.
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
logger.info(
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
# --- Stop hook: capture transcript path for stateless resume ---
|
||||
async def stop_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Capture transcript path when SDK finishes processing.
|
||||
|
||||
The Stop hook fires while the CLI process is still alive, giving us
|
||||
a reliable window to read the JSONL transcript before SIGTERM.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
||||
|
||||
if transcript_path and on_stop:
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
||||
)
|
||||
on_stop(transcript_path, sdk_session_id)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
"PostToolUseFailure": [
|
||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||
],
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
|
||||
if on_stop is not None:
|
||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
||||
|
||||
return hooks
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
logger.warning("claude-agent-sdk not available, security hooks disabled")
|
||||
return {}
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Unit tests for SDK security hooks."""
|
||||
|
||||
import os
|
||||
|
||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||
|
||||
SDK_CWD = "/tmp/copilot-abc123"
|
||||
|
||||
|
||||
def _is_denied(result: dict) -> bool:
|
||||
hook = result.get("hookSpecificOutput", {})
|
||||
return hook.get("permissionDecision") == "deny"
|
||||
|
||||
|
||||
# -- Blocked tools -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_blocked_tools_denied():
|
||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
||||
result = _validate_tool_access(tool, {})
|
||||
assert _is_denied(result), f"{tool} should be blocked"
|
||||
|
||||
|
||||
def test_unknown_tool_allowed():
|
||||
result = _validate_tool_access("SomeCustomTool", {})
|
||||
assert result == {}
|
||||
|
||||
|
||||
# -- Workspace-scoped tools --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_write_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_edit_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_glob_within_workspace_allowed():
|
||||
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_grep_within_workspace_allowed():
|
||||
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_write_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_traversal_attack_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read",
|
||||
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
|
||||
sdk_cwd=SDK_CWD,
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_no_path_allowed():
|
||||
"""Glob/Grep without a path argument defaults to cwd — should pass."""
|
||||
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_no_cwd_denies_absolute():
|
||||
"""If no sdk_cwd is set, absolute paths are denied."""
|
||||
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Tool-results directory --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_claude_projects_without_tool_results_denied():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
||||
|
||||
|
||||
def test_bash_builtin_always_blocked():
|
||||
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
|
||||
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Dangerous patterns ------------------------------------------------------
|
||||
|
||||
|
||||
def test_dangerous_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_subprocess_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- User isolation ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_workspace_path_traversal_blocked():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_absolute_path_blocked():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_normal_path_allowed():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_non_workspace_tool_passes_isolation():
|
||||
result = _validate_user_isolation(
|
||||
"find_agent", {"query": "email"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
@@ -0,0 +1,751 @@
|
||||
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import stream_registry
|
||||
from ..config import ChatConfig
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
_execute_long_running_tool_with_streaming,
|
||||
_generate_session_title,
|
||||
)
|
||||
from ..tools.models import OperationPendingResponse, OperationStartedResponse
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .tool_adapter import (
|
||||
COPILOT_TOOL_NAMES,
|
||||
LongRunningCallback,
|
||||
create_copilot_mcp_server,
|
||||
set_execution_context,
|
||||
)
|
||||
from .transcript import (
|
||||
download_transcript,
|
||||
read_transcript_file,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapturedTranscript:
|
||||
"""Info captured by the SDK Stop hook for stateless --resume."""
|
||||
|
||||
path: str = ""
|
||||
sdk_session_id: str = ""
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
return bool(self.path)
|
||||
|
||||
|
||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
|
||||
# Appended to the system prompt to inform the agent about available tools.
|
||||
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
||||
# which has kernel-level network isolation (unshare --net).
|
||||
_SDK_TOOL_SUPPLEMENT = """
|
||||
|
||||
## Tool notes
|
||||
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs in a network-isolated sandbox.
|
||||
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
|
||||
same working directory. Files created by one are readable by the other.
|
||||
These files are **ephemeral** — they exist only for the current session.
|
||||
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
|
||||
for files that should persist across sessions (stored in cloud storage).
|
||||
- Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
"""
|
||||
|
||||
|
||||
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
||||
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
||||
|
||||
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
||||
existing background infrastructure: stream_registry (Redis Streams),
|
||||
database persistence, and SSE reconnection. This means results survive
|
||||
page refreshes / pod restarts, and the frontend shows the proper loading
|
||||
widget with progress updates.
|
||||
|
||||
The returned callback matches the ``LongRunningCallback`` signature:
|
||||
``(tool_name, args, session) -> MCP response dict``.
|
||||
"""
|
||||
|
||||
async def _callback(
|
||||
tool_name: str, args: dict[str, Any], session: ChatSession
|
||||
) -> dict[str, Any]:
|
||||
operation_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
session_id = session.session_id
|
||||
|
||||
# --- Build user-friendly messages (matches non-SDK service) ---
|
||||
if tool_name == "create_agent":
|
||||
desc = args.get("description", "")
|
||||
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
|
||||
pending_msg = (
|
||||
f"Creating your agent: {desc_preview}"
|
||||
if desc_preview
|
||||
else "Creating agent... This may take a few minutes."
|
||||
)
|
||||
started_msg = (
|
||||
"Agent creation started. You can close this tab - "
|
||||
"check your library in a few minutes."
|
||||
)
|
||||
elif tool_name == "edit_agent":
|
||||
changes = args.get("changes", "")
|
||||
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
||||
pending_msg = (
|
||||
f"Editing agent: {changes_preview}"
|
||||
if changes_preview
|
||||
else "Editing agent... This may take a few minutes."
|
||||
)
|
||||
started_msg = (
|
||||
"Agent edit started. You can close this tab - "
|
||||
"check your library in a few minutes."
|
||||
)
|
||||
else:
|
||||
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
||||
started_msg = (
|
||||
f"{tool_name} started. You can close this tab - "
|
||||
"check back in a few minutes."
|
||||
)
|
||||
|
||||
# --- Register task in Redis for SSE reconnection ---
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# --- Save OperationPendingResponse to chat history ---
|
||||
pending_message = ChatMessage(
|
||||
role="tool",
|
||||
content=OperationPendingResponse(
|
||||
message=pending_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump_json(),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
||||
bg_task = asyncio.create_task(
|
||||
_execute_long_running_tool_with_streaming(
|
||||
tool_name=tool_name,
|
||||
parameters=args,
|
||||
tool_call_id=tool_call_id,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(bg_task)
|
||||
bg_task.add_done_callback(_background_tasks.discard)
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
|
||||
logger.info(
|
||||
f"[SDK] Long-running tool {tool_name} delegated to background "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
|
||||
# --- Return OperationStartedResponse as MCP tool result ---
|
||||
# This flows through SDK → response adapter → frontend, triggering
|
||||
# the loading widget with SSE reconnection support.
|
||||
started_json = OperationStartedResponse(
|
||||
message=started_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
task_id=task_id,
|
||||
).model_dump_json()
|
||||
|
||||
return {
|
||||
"content": [{"type": "text", "text": started_json}],
|
||||
"isError": False,
|
||||
}
|
||||
|
||||
return _callback
|
||||
|
||||
|
||||
def _resolve_sdk_model() -> str | None:
|
||||
"""Resolve the model name for the Claude Agent SDK CLI.
|
||||
|
||||
Uses ``config.claude_agent_model`` if set, otherwise derives from
|
||||
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
|
||||
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
|
||||
"""
|
||||
if config.claude_agent_model:
|
||||
return config.claude_agent_model
|
||||
model = config.model
|
||||
if "/" in model:
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
def _build_sdk_env() -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI process.
|
||||
|
||||
Routes API calls through OpenRouter (or a custom base_url) using
|
||||
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
|
||||
This gives per-call token and cost tracking on the OpenRouter dashboard.
|
||||
|
||||
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
|
||||
token are both present — otherwise returns an empty dict so the SDK
|
||||
falls back to its default credentials.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
if config.api_key and config.base_url:
|
||||
# Strip /v1 suffix — SDK expects the base URL without a version path
|
||||
base = config.base_url.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
if not base or not base.startswith("http"):
|
||||
# Invalid base_url — don't override SDK defaults
|
||||
return env
|
||||
env["ANTHROPIC_BASE_URL"] = base
|
||||
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
|
||||
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
|
||||
env["ANTHROPIC_API_KEY"] = ""
|
||||
return env
|
||||
|
||||
|
||||
def _make_sdk_cwd(session_id: str) -> str:
|
||||
"""Create a safe, session-specific working directory path.
|
||||
|
||||
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
|
||||
(single source of truth for path sanitization) and adds a defence-in-depth
|
||||
assertion.
|
||||
"""
|
||||
cwd = make_session_path(session_id)
|
||||
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
|
||||
cwd = os.path.normpath(cwd)
|
||||
if not cwd.startswith(_SDK_CWD_PREFIX):
|
||||
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
|
||||
return cwd
|
||||
|
||||
|
||||
def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""Remove SDK tool-result files for a specific session working directory.
|
||||
|
||||
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
|
||||
We clean only the specific cwd's results to avoid race conditions between
|
||||
concurrent sessions.
|
||||
|
||||
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Validate cwd is under the expected prefix
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||
return
|
||||
|
||||
# SDK encodes the cwd path by replacing '/' with '-'
|
||||
encoded_cwd = normalized.replace("/", "-")
|
||||
|
||||
# Construct the project directory path (known-safe home expansion)
|
||||
claude_projects = os.path.expanduser("~/.claude/projects")
|
||||
project_dir = os.path.join(claude_projects, encoded_cwd)
|
||||
|
||||
# Security check 3: Validate project_dir is under ~/.claude/projects
|
||||
project_dir = os.path.normpath(project_dir)
|
||||
if not project_dir.startswith(claude_projects):
|
||||
logger.warning(
|
||||
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
|
||||
)
|
||||
return
|
||||
|
||||
results_dir = os.path.join(project_dir, "tool-results")
|
||||
if os.path.isdir(results_dir):
|
||||
for filename in os.listdir(results_dir):
|
||||
file_path = os.path.join(results_dir, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Also clean up the temp cwd directory itself
|
||||
try:
|
||||
shutil.rmtree(normalized, ignore_errors=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
async def _compress_conversation_history(
|
||||
session: ChatSession,
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress prior conversation messages if they exceed the token threshold.
|
||||
|
||||
Uses the shared compress_context() from prompt.py which supports:
|
||||
- LLM summarization of old messages (keeps recent ones intact)
|
||||
- Progressive content truncation as fallback
|
||||
- Middle-out deletion as last resort
|
||||
|
||||
Returns the compressed prior messages (everything except the current message).
|
||||
"""
|
||||
prior = session.messages[:-1]
|
||||
if len(prior) < 2:
|
||||
return prior
|
||||
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
# Convert ChatMessages to dicts for compress_context
|
||||
messages_dict = []
|
||||
for msg in prior:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
if msg.tool_calls:
|
||||
msg_dict["tool_calls"] = msg.tool_calls
|
||||
if msg.tool_call_id:
|
||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
async with openai.AsyncOpenAI(
|
||||
api_key=config.api_key, base_url=config.base_url, timeout=30.0
|
||||
) as client:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
||||
# Fall back to truncation-only (no LLM summarization)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
f"[SDK] Context compacted: {result.original_token_count} -> "
|
||||
f"{result.token_count} tokens "
|
||||
f"({result.messages_summarized} summarized, "
|
||||
f"{result.messages_dropped} dropped)"
|
||||
)
|
||||
# Convert compressed dicts back to ChatMessages
|
||||
return [
|
||||
ChatMessage(
|
||||
role=m["role"],
|
||||
content=m.get("content"),
|
||||
tool_calls=m.get("tool_calls"),
|
||||
tool_call_id=m.get("tool_call_id"),
|
||||
)
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return prior
|
||||
|
||||
|
||||
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
"""Format conversation messages into a context prefix for the user message.
|
||||
|
||||
Returns a string like:
|
||||
<conversation_history>
|
||||
User: hello
|
||||
You responded: Hi! How can I help?
|
||||
</conversation_history>
|
||||
|
||||
Returns None if there are no messages to format.
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
lines: list[str] = []
|
||||
for msg in messages:
|
||||
if not msg.content:
|
||||
continue
|
||||
if msg.role == "user":
|
||||
lines.append(f"User: {msg.content}")
|
||||
elif msg.role == "assistant":
|
||||
lines.append(f"You responded: {msg.content}")
|
||||
# Skip tool messages — they're internal details
|
||||
|
||||
if not lines:
|
||||
return None
|
||||
|
||||
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None, # noqa: ARG001
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
|
||||
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||
"""
|
||||
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
if message:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if is_user_message else "assistant", content=message
|
||||
)
|
||||
)
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Build system prompt (reuses non-SDK path with Langfuse support)
|
||||
has_history = len(session.messages) > 1
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=has_history
|
||||
)
|
||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||
message_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
|
||||
stream_completed = False
|
||||
# Initialise sdk_cwd before the try so the finally can reference it
|
||||
# even if _make_sdk_cwd raises (in that case it stays as "").
|
||||
sdk_cwd = ""
|
||||
use_resume = False
|
||||
|
||||
try:
|
||||
# Use a session-specific temp dir to avoid cleanup race conditions
|
||||
# between concurrent sessions.
|
||||
sdk_cwd = _make_sdk_cwd(session_id)
|
||||
os.makedirs(sdk_cwd, exist_ok=True)
|
||||
|
||||
set_execution_context(
|
||||
user_id,
|
||||
session,
|
||||
long_running_callback=_build_long_running_callback(user_id),
|
||||
)
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
# Fail fast when no API credentials are available at all
|
||||
sdk_env = _build_sdk_env()
|
||||
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY "
|
||||
"(or CHAT_API_KEY) for OpenRouter routing, "
|
||||
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
||||
)
|
||||
|
||||
mcp_server = create_copilot_mcp_server()
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
# --- Transcript capture via Stop hook ---
|
||||
captured_transcript = CapturedTranscript()
|
||||
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
|
||||
security_hooks = create_security_hooks(
|
||||
user_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
max_subtasks=config.claude_agent_max_subtasks,
|
||||
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
||||
)
|
||||
|
||||
# --- Resume strategy: download transcript from bucket ---
|
||||
resume_file: str | None = None
|
||||
use_resume = False
|
||||
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
transcript_content = await download_transcript(user_id, session_id)
|
||||
if transcript_content and validate_transcript(transcript_content):
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
transcript_content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
logger.info(
|
||||
f"[SDK] Using --resume with transcript "
|
||||
f"({len(transcript_content)} bytes)"
|
||||
)
|
||||
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
"mcp_servers": {"copilot": mcp_server},
|
||||
"allowed_tools": COPILOT_TOOL_NAMES,
|
||||
"disallowed_tools": ["Bash"],
|
||||
"hooks": security_hooks,
|
||||
"cwd": sdk_cwd,
|
||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||
}
|
||||
if sdk_env:
|
||||
sdk_options_kwargs["model"] = sdk_model
|
||||
sdk_options_kwargs["env"] = sdk_env
|
||||
if use_resume and resume_file:
|
||||
sdk_options_kwargs["resume"] = resume_file
|
||||
|
||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
|
||||
|
||||
adapter = SDKResponseAdapter(message_id=message_id)
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
current_message = message or ""
|
||||
if not current_message and session.messages:
|
||||
last_user = [m for m in session.messages if m.role == "user"]
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
if not current_message.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
code="empty_prompt",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Build query: with --resume the CLI already has full
|
||||
# context, so we only send the new message. Without
|
||||
# resume, compress history into a context prefix.
|
||||
query_message = current_message
|
||||
if not use_resume and len(session.messages) > 1:
|
||||
logger.warning(
|
||||
f"[SDK] Using compression fallback for session "
|
||||
f"{session_id} ({len(session.messages)} messages) — "
|
||||
f"no transcript available for --resume"
|
||||
)
|
||||
compressed = await _compress_conversation_history(session)
|
||||
history_context = _format_conversation_context(compressed)
|
||||
if history_context:
|
||||
query_message = (
|
||||
f"{history_context}\n\n"
|
||||
f"Now, the user says:\n{current_message}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
|
||||
)
|
||||
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
|
||||
await client.query(query_message, session_id=session_id)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
|
||||
async for sdk_msg in client.receive_messages():
|
||||
logger.debug(
|
||||
f"[SDK] Received: {type(sdk_msg).__name__} "
|
||||
f"{getattr(sdk_msg, 'subtype', '')}"
|
||||
)
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
|
||||
yield response
|
||||
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, start a new assistant
|
||||
# message for the post-tool text.
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
accumulated_tool_calls = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(response.input or {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
if stream_completed:
|
||||
break
|
||||
|
||||
if (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# --- Capture transcript while CLI is still alive ---
|
||||
# Must happen INSIDE async with: close() sends SIGTERM
|
||||
# which kills the CLI before it can flush the JSONL.
|
||||
if (
|
||||
config.claude_agent_use_resume
|
||||
and user_id
|
||||
and captured_transcript.available
|
||||
):
|
||||
# Give CLI time to flush JSONL writes before we read
|
||||
await asyncio.sleep(0.5)
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
if raw_transcript:
|
||||
task = asyncio.create_task(
|
||||
_upload_transcript_bg(user_id, session_id, raw_transcript)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
else:
|
||||
logger.debug("[SDK] Stop hook fired but transcript not usable")
|
||||
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
"claude-agent-sdk is not installed. "
|
||||
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
|
||||
"to use the OpenAI-compatible fallback."
|
||||
)
|
||||
|
||||
await upsert_chat_session(session)
|
||||
logger.debug(
|
||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||
)
|
||||
if not stream_completed:
|
||||
yield StreamFinish()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="sdk_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
finally:
|
||||
if sdk_cwd:
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
|
||||
async def _upload_transcript_bg(
|
||||
user_id: str, session_id: str, raw_content: str
|
||||
) -> None:
|
||||
"""Background task to strip progress entries and upload transcript."""
|
||||
try:
|
||||
await upload_transcript(user_id, session_id, raw_content)
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
"""Background task to update session title."""
|
||||
try:
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
@@ -0,0 +1,325 @@
|
||||
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||
|
||||
This module provides the adapter layer that converts existing BaseTool implementations
|
||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||
|
||||
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
|
||||
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
|
||||
via a callback provided by the service layer. This avoids wasteful SDK polling
|
||||
and makes results survive page refreshes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import TOOL_REGISTRY
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Context variables to pass user/session info to tool execution
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
|
||||
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Callback type for delegating long-running tools to the non-SDK infrastructure.
|
||||
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
|
||||
LongRunningCallback = Callable[
|
||||
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
|
||||
]
|
||||
|
||||
# ContextVar so the service layer can inject the callback per-request.
|
||||
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
|
||||
"long_running_callback", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
long_running_callback: LongRunningCallback | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id and session information.
|
||||
|
||||
Args:
|
||||
user_id: Current user's ID.
|
||||
session: Current chat session.
|
||||
long_running_callback: Optional callback to delegate long-running tools
|
||||
to the non-SDK background infrastructure (stream_registry + Redis).
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_pending_tool_outputs.set({})
|
||||
_long_running_callback.set(long_running_callback)
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
_current_user_id.get(),
|
||||
_current_session.get(),
|
||||
)
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
"""Pop and return the stashed full output for *tool_name*.
|
||||
|
||||
The SDK CLI may truncate large tool results (writing them to disk and
|
||||
replacing the content with a file reference). This stash keeps the
|
||||
original MCP output so the response adapter can forward it to the
|
||||
frontend for proper widget rendering.
|
||||
|
||||
Returns ``None`` if nothing was stashed for *tool_name*.
|
||||
"""
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is None:
|
||||
return None
|
||||
return pending.pop(tool_name, None)
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
base_tool: BaseTool,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=effective_id,
|
||||
**args,
|
||||
)
|
||||
|
||||
text = (
|
||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
||||
)
|
||||
|
||||
# Stash the full output before the SDK potentially truncates it.
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is not None:
|
||||
pending[base_tool.name] = text
|
||||
|
||||
return {
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
|
||||
def _mcp_error(message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"content": [
|
||||
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
This wraps the existing BaseTool._execute method to be compatible
|
||||
with the Claude Agent SDK MCP tool format.
|
||||
|
||||
Long-running tools (``is_long_running=True``) are delegated to the
|
||||
non-SDK background infrastructure via a callback set in the execution
|
||||
context. The callback persists the operation in Redis (stream_registry)
|
||||
so results survive page refreshes and pod restarts.
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
user_id, session = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
return _mcp_error("No session context available")
|
||||
|
||||
# --- Long-running: delegate to non-SDK background infrastructure ---
|
||||
if base_tool.is_long_running:
|
||||
callback = _long_running_callback.get(None)
|
||||
if callback:
|
||||
try:
|
||||
return await callback(base_tool.name, args, session)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Long-running callback failed for {base_tool.name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
|
||||
# No callback — fall through to synchronous execution
|
||||
logger.warning(
|
||||
f"[SDK] No long-running callback for {base_tool.name}, "
|
||||
f"executing synchronously (may block)"
|
||||
)
|
||||
|
||||
# --- Normal (fast) tool: execute synchronously ---
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
|
||||
|
||||
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
"""Build a JSON Schema input schema for a tool."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": base_tool.parameters.get("properties", {}),
|
||||
"required": base_tool.parameters.get("required", []),
|
||||
}
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
||||
|
||||
After reading, the file is deleted to prevent accumulation in long-running pods.
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
|
||||
# Security: only allow reads under ~/.claude/projects/**/tool-results/
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
try:
|
||||
with open(real_path) as f:
|
||||
lines = f.readlines()
|
||||
selected = lines[offset : offset + limit]
|
||||
content = "".join(selected)
|
||||
# Clean up to prevent accumulation in long-running pods
|
||||
try:
|
||||
os.remove(real_path)
|
||||
except OSError:
|
||||
pass
|
||||
return {"content": [{"type": "text", "text": content}], "isError": False}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
_READ_TOOL_NAME = "Read"
|
||||
_READ_TOOL_DESCRIPTION = (
|
||||
"Read a file from the local filesystem. "
|
||||
"Use offset and limit to read specific line ranges for large files."
|
||||
)
|
||||
_READ_TOOL_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to read",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (0-indexed). Default: 0",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read. Default: 2000",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
}
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Add the Read tool so the SDK can read back oversized tool results
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_read_file_handler)
|
||||
sdk_tools.append(read_tool)
|
||||
|
||||
server = create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
|
||||
|
||||
# SDK built-in tools allowed within the workspace directory.
|
||||
# Security hooks validate that file paths stay within sdk_cwd.
|
||||
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
|
||||
# which provides kernel-level network isolation via unshare --net.
|
||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"]
|
||||
|
||||
# List of tool names for allowed_tools configuration
|
||||
# Include MCP tools, the MCP Read tool for oversized results,
|
||||
# and SDK built-in file tools for workspace operations.
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
@@ -0,0 +1,355 @@
|
||||
"""JSONL transcript management for stateless multi-turn resume.
|
||||
|
||||
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
||||
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
||||
(progress entries, metadata), and upload the result to bucket storage. On the
|
||||
next turn we download the transcript, write it to a temp file, and pass
|
||||
``--resume`` so the CLI can reconstruct the full conversation.
|
||||
|
||||
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
|
||||
|
||||
# Entry types that can be safely removed from the transcript without breaking
|
||||
# the parentUuid conversation tree that ``--resume`` relies on.
|
||||
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
|
||||
# - file-history-snapshot: undo tracking metadata
|
||||
# - queue-operation: internal queue bookkeeping
|
||||
# - summary: session summaries
|
||||
# - pr-link: PR link metadata
|
||||
STRIPPABLE_TYPES = frozenset(
|
||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||
)
|
||||
|
||||
# Workspace storage constants — deterministic path from session_id.
|
||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progress stripping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def strip_progress_entries(content: str) -> str:
|
||||
"""Remove progress/metadata entries from a JSONL transcript.
|
||||
|
||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
||||
Typically reduces transcript size by ~30%.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
entries: list[dict] = []
|
||||
for line in lines:
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
# Keep unparseable lines as-is (safety)
|
||||
entries.append({"_raw": line})
|
||||
|
||||
stripped_uuids: set[str] = set()
|
||||
uuid_to_parent: dict[str, str] = {}
|
||||
kept: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
if "_raw" in entry:
|
||||
kept.append(entry)
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
parent = entry.get("parentUuid", "")
|
||||
entry_type = entry.get("type", "")
|
||||
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
if uid:
|
||||
stripped_uuids.add(uid)
|
||||
else:
|
||||
kept.append(entry)
|
||||
|
||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
while parent in stripped_uuids:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
|
||||
result_lines: list[str] = []
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
result_lines.append(entry["_raw"])
|
||||
else:
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_transcript_file(transcript_path: str) -> str | None:
|
||||
"""Read a JSONL transcript file from disk.
|
||||
|
||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
||||
or only contains metadata (≤2 lines with no conversation messages).
|
||||
"""
|
||||
if not transcript_path or not os.path.isfile(transcript_path):
|
||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(transcript_path) as f:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug(f"[Transcript] Empty file: {transcript_path}")
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
# Metadata-only files have 1 line (single queue-operation or snapshot).
|
||||
logger.debug(
|
||||
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Quick structural validation — parse first and last lines.
|
||||
json.loads(lines[0])
|
||||
json.loads(lines[-1])
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
f"{len(content)} bytes from {transcript_path}"
|
||||
)
|
||||
return content
|
||||
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
"""Sanitize an ID for safe use in file paths.
|
||||
|
||||
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
|
||||
everything else and truncate to *max_len* so the result cannot introduce
|
||||
path separators or other special characters.
|
||||
"""
|
||||
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
|
||||
return cleaned or "unknown"
|
||||
|
||||
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
transcript_content: str,
|
||||
session_id: str,
|
||||
cwd: str,
|
||||
) -> str | None:
|
||||
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
|
||||
|
||||
The file lives in the session working directory so it is cleaned up
|
||||
automatically when the session ends.
|
||||
|
||||
Returns the absolute path to the file, or ``None`` on failure.
|
||||
"""
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
return None
|
||||
|
||||
try:
|
||||
os.makedirs(real_cwd, exist_ok=True)
|
||||
safe_id = _sanitize_id(session_id, max_len=8)
|
||||
jsonl_path = os.path.realpath(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def validate_transcript(content: str | None) -> bool:
|
||||
"""Check that a transcript has actual conversation messages.
|
||||
|
||||
A valid transcript for resume needs at least one user message and one
|
||||
assistant message (not just queue-operation / file-history-snapshot
|
||||
metadata).
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return False
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
has_user = False
|
||||
has_assistant = False
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
msg_type = entry.get("type")
|
||||
if msg_type == "user":
|
||||
has_user = True
|
||||
elif msg_type == "assistant":
|
||||
has_assistant = True
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
return has_user and has_assistant
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bucket storage (GCS / local via WorkspaceStorageBackend)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript.
|
||||
|
||||
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
|
||||
IDs are sanitized to hex+hyphen to prevent path traversal.
|
||||
"""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.jsonl",
|
||||
)
|
||||
|
||||
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path string that ``retrieve()`` expects.
|
||||
|
||||
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
|
||||
``local://workspace_id/file_id/filename``. Since we use deterministic
|
||||
arguments we can reconstruct the same path for download/delete without
|
||||
having stored the return value.
|
||||
"""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
return f"gcs://{backend.bucket_name}/{blob}"
|
||||
else:
|
||||
# LocalWorkspaceStorage returns local://{relative_path}
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
|
||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
||||
what is already stored. Since JSONL is append-only, the latest transcript
|
||||
is always the longest. This prevents a slow/stale background task from
|
||||
clobbering a newer upload from a concurrent turn.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
logger.warning(
|
||||
f"[Transcript] Skipping upload — stripped content is not a valid "
|
||||
f"transcript for session {session_id}"
|
||||
)
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
new_size = len(encoded)
|
||||
|
||||
# Check existing transcript size to avoid overwriting newer with older
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
try:
|
||||
existing = await storage.retrieve(path)
|
||||
if len(existing) >= new_size:
|
||||
logger.info(
|
||||
f"[Transcript] Skipping upload — existing transcript "
|
||||
f"({len(existing)}B) >= new ({new_size}B) for session "
|
||||
f"{session_id}"
|
||||
)
|
||||
return
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No existing transcript or retrieval error — proceed with upload
|
||||
|
||||
await storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
)
|
||||
logger.info(
|
||||
f"[Transcript] Uploaded {new_size} bytes "
|
||||
f"(stripped from {len(content)}) for session {session_id}"
|
||||
)
|
||||
|
||||
|
||||
async def download_transcript(user_id: str, session_id: str) -> str | None:
|
||||
"""Download transcript from bucket storage.
|
||||
|
||||
Returns the JSONL content string, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
logger.info(
|
||||
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
|
||||
)
|
||||
return content
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
await storage.delete(path)
|
||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
||||
@@ -27,7 +27,6 @@ from openai.types.chat import (
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
@@ -36,6 +35,7 @@ from backend.data.understanding import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
@@ -245,12 +245,16 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||
|
||||
|
||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
async def _build_system_prompt(
|
||||
user_id: str | None, has_conversation_history: bool = False
|
||||
) -> tuple[str, Any]:
|
||||
"""Build the full system prompt including business understanding if available.
|
||||
|
||||
Args:
|
||||
user_id: The user ID for fetching business understanding
|
||||
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||
user_id: The user ID for fetching business understanding.
|
||||
has_conversation_history: Whether there's existing conversation history.
|
||||
If True, we don't tell the model to greet/introduce (since they're
|
||||
already in a conversation).
|
||||
|
||||
Returns:
|
||||
Tuple of (compiled prompt string, business understanding object)
|
||||
@@ -266,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
|
||||
if understanding:
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
elif has_conversation_history:
|
||||
context = "No prior understanding saved yet. Continue the existing conversation naturally."
|
||||
else:
|
||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||
|
||||
@@ -374,7 +380,6 @@ async def stream_chat_completion(
|
||||
|
||||
Raises:
|
||||
NotFoundError: If session_id is invalid
|
||||
ValueError: If max_context_messages is exceeded
|
||||
|
||||
"""
|
||||
completion_start = time.monotonic()
|
||||
@@ -459,8 +464,9 @@ async def stream_chat_completion(
|
||||
|
||||
# Generate title for new sessions on first user message (non-blocking)
|
||||
# Check: is_user_message, no title yet, and this is the first user message
|
||||
if is_user_message and message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
first_user_msg = message or (user_messages[0].content if user_messages else None)
|
||||
if is_user_message and first_user_msg and not session.title:
|
||||
if len(user_messages) == 1:
|
||||
# First user message - generate title in background
|
||||
import asyncio
|
||||
@@ -468,7 +474,7 @@ async def stream_chat_completion(
|
||||
# Capture only the values we need (not the session object) to avoid
|
||||
# stale data issues when the main flow modifies the session
|
||||
captured_session_id = session_id
|
||||
captured_message = message
|
||||
captured_message = first_user_msg
|
||||
captured_user_id = user_id
|
||||
|
||||
async def _update_title():
|
||||
@@ -1237,7 +1243,7 @@ async def _stream_chat_chunks(
|
||||
|
||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
|
||||
f"session={session.session_id}, user={session.user_id}",
|
||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||
)
|
||||
@@ -1744,7 +1750,7 @@ async def _update_pending_operation(
|
||||
This is called by background tasks when long-running operations complete.
|
||||
"""
|
||||
# Update the message in database
|
||||
updated = await chat_db().update_tool_message_content(
|
||||
updated = await chat_db.update_tool_message_content(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
new_content=result,
|
||||
@@ -0,0 +1,178 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from .sdk import service as sdk_service
|
||||
from .sdk.transcript import download_transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
||||
user_id=session.user_id,
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the SDK --resume path captures and uses transcripts across turns.
|
||||
|
||||
Turn 1: Send a message containing a unique keyword.
|
||||
Turn 2: Ask the model to recall that keyword — proving the transcript was
|
||||
persisted and restored via --resume.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
from .config import ChatConfig
|
||||
|
||||
cfg = ChatConfig()
|
||||
if not cfg.claude_agent_use_resume:
|
||||
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
keyword = "ZEPHYR42"
|
||||
turn1_msg = (
|
||||
f"Please remember this special keyword: {keyword}. "
|
||||
"Just confirm you've noted it, keep your response brief."
|
||||
)
|
||||
turn1_text = ""
|
||||
turn1_errors: list[str] = []
|
||||
turn1_ended = False
|
||||
|
||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
||||
session.session_id,
|
||||
turn1_msg,
|
||||
user_id=test_user_id,
|
||||
):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
turn1_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn1_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
turn1_ended = True
|
||||
|
||||
assert turn1_ended, "Turn 1 did not finish"
|
||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
||||
assert turn1_text, "Turn 1 produced no text"
|
||||
|
||||
# Wait for background upload task to complete (retry up to 5s)
|
||||
transcript = None
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0.5)
|
||||
transcript = await download_transcript(test_user_id, session.session_id)
|
||||
if transcript:
|
||||
break
|
||||
assert transcript, (
|
||||
"Transcript was not uploaded to bucket after turn 1 — "
|
||||
"Stop hook may not have fired or transcript was too small"
|
||||
)
|
||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert session, "Session not found after turn 1"
|
||||
|
||||
# --- Turn 2: ask model to recall the keyword ---
|
||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
||||
turn2_text = ""
|
||||
turn2_errors: list[str] = []
|
||||
turn2_ended = False
|
||||
|
||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
||||
session.session_id,
|
||||
turn2_msg,
|
||||
user_id=test_user_id,
|
||||
session=session,
|
||||
):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
turn2_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn2_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
turn2_ended = True
|
||||
|
||||
assert turn2_ended, "Turn 2 did not finish"
|
||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
||||
assert turn2_text, "Turn 2 produced no text"
|
||||
assert keyword in turn2_text, (
|
||||
f"Model did not recall keyword '{keyword}' in turn 2. "
|
||||
f"Response: {turn2_text[:200]}"
|
||||
)
|
||||
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
||||
@@ -814,6 +814,28 @@ async def get_active_task_for_session(
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
# Auto-expire stale tasks that exceeded stream_timeout
|
||||
created_at_str = meta.get("created_at", "")
|
||||
if created_at_str:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str)
|
||||
age_seconds = (
|
||||
datetime.now(timezone.utc) - created_at
|
||||
).total_seconds()
|
||||
if age_seconds > config.stream_timeout:
|
||||
logger.warning(
|
||||
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
|
||||
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
|
||||
)
|
||||
await mark_task_completed(task_id, "failed")
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||
)
|
||||
|
||||
# Get the last message ID from Redis Stream
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
last_id = "0-0"
|
||||
@@ -3,12 +3,14 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_tool_called
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import track_tool_called
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .check_operation_status import CheckOperationStatusTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -19,6 +21,7 @@ from .get_doc_page import GetDocPageTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
@@ -27,7 +30,7 @@ from .workspace_files import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,9 +46,14 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"check_operation_status": CheckOperationStatusTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Workspace tools for CoPilot file operations
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||
"write_workspace_file": WriteWorkspaceFileTool(),
|
||||
@@ -6,11 +6,11 @@ import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
@@ -3,9 +3,11 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await understanding_db().upsert_business_understanding(
|
||||
user_id, input_data
|
||||
)
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
@@ -5,8 +5,9 @@ import re
|
||||
import uuid
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .service import (
|
||||
@@ -144,9 +145,8 @@ async def get_library_agent_by_id(
|
||||
Returns:
|
||||
LibraryAgentSummary if found, None otherwise
|
||||
"""
|
||||
db = library_db()
|
||||
try:
|
||||
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
|
||||
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||
|
||||
try:
|
||||
agent = await db.get_library_agent(agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
try:
|
||||
response = await library_db().list_library_agents(
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=1,
|
||||
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
|
||||
List of LibraryAgentSummary with full input/output schemas
|
||||
"""
|
||||
try:
|
||||
response = await store_db().get_store_agents(
|
||||
response = await store_db.get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
|
||||
return []
|
||||
|
||||
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||
graphs = await graph_db().get_store_listed_graphs(*graph_ids)
|
||||
graphs = await get_store_listed_graphs(*graph_ids)
|
||||
|
||||
results: list[LibraryAgentSummary] = []
|
||||
for agent in agents_with_graphs:
|
||||
@@ -673,10 +673,9 @@ async def save_agent_to_library(
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
graph = json_to_graph(agent_json)
|
||||
db = library_db()
|
||||
if is_update:
|
||||
return await db.update_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id)
|
||||
return await library_db.update_graph_in_library(graph, user_id)
|
||||
return await library_db.create_graph_in_library(graph, user_id)
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
@@ -736,14 +735,12 @@ async def get_agent_as_json(
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
db = graph_db()
|
||||
|
||||
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db().get_library_agent(agent_id, user_id)
|
||||
graph = await db.get_graph(
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
@@ -7,9 +7,10 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import execution_db, library_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -164,12 +165,10 @@ class AgentOutputTool(BaseTool):
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(library_agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
@@ -183,7 +182,7 @@ class AgentOutputTool(BaseTool):
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
@@ -195,7 +194,7 @@ class AgentOutputTool(BaseTool):
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await lib_db.list_library_agents(
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
@@ -229,11 +228,9 @@ class AgentOutputTool(BaseTool):
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
exec_db = execution_db()
|
||||
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await exec_db.get_graph_execution(
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -243,7 +240,7 @@ class AgentOutputTool(BaseTool):
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await exec_db.get_graph_executions(
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
@@ -257,7 +254,7 @@ class AgentOutputTool(BaseTool):
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -265,7 +262,7 @@ class AgentOutputTool(BaseTool):
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -383,7 +380,7 @@ class AgentOutputTool(BaseTool):
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db().get_graph_execution(
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -395,7 +392,7 @@ class AgentOutputTool(BaseTool):
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db().get_library_agent_by_graph_id(
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
@@ -4,7 +4,8 @@ import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
@@ -44,10 +45,8 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
Returns:
|
||||
AgentInfo if found, None otherwise
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
try:
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
@@ -72,7 +71,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
@@ -134,7 +133,7 @@ async def search_agents(
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
@@ -160,7 +159,7 @@ async def search_agents(
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db().list_library_agents(
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
|
||||
|
||||
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
|
||||
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
|
||||
read-only, writable workspace only, clean env, no network.
|
||||
|
||||
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
|
||||
available (e.g. macOS development).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
BashExecResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.api.features.chat.tools.sandbox import (
|
||||
get_workspace_dir,
|
||||
has_full_sandbox,
|
||||
run_sandboxed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BashExecTool(BaseTool):
|
||||
"""Execute Bash commands in a bubblewrap sandbox."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "bash_exec"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
if not has_full_sandbox():
|
||||
return (
|
||||
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
||||
"available on this platform. Do not call this tool."
|
||||
)
|
||||
return (
|
||||
"Execute a Bash command or script in a bubblewrap sandbox. "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The sandbox shares the same working directory as the SDK Read/Write "
|
||||
"tools — files created by either are accessible to both. "
|
||||
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
|
||||
"visible read-only, the per-session workspace is the only writable "
|
||||
"path, environment variables are wiped (no secrets), all network "
|
||||
"access is blocked at the kernel level, and resource limits are "
|
||||
"enforced (max 64 processes, 512MB memory, 50MB file size). "
|
||||
"Application code, configs, and other directories are NOT accessible. "
|
||||
"To fetch web content, use the web_fetch tool instead. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr. "
|
||||
"Useful for file manipulation, data processing with Unix tools "
|
||||
"(grep, awk, sed, jq, etc.), and running shell scripts."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Bash command or script to execute.",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max execution time in seconds (default 30, max 120)."
|
||||
),
|
||||
"default": 30,
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
command: str = (kwargs.get("command") or "").strip()
|
||||
timeout: int = kwargs.get("timeout", 30)
|
||||
|
||||
if not command:
|
||||
return ErrorResponse(
|
||||
message="No command provided.",
|
||||
error="empty_command",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
workspace = get_workspace_dir(session_id or "default")
|
||||
|
||||
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
||||
command=["bash", "-c", command],
|
||||
cwd=workspace,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return BashExecResponse(
|
||||
message=(
|
||||
"Execution timed out"
|
||||
if timed_out
|
||||
else f"Command executed (exit {exit_code})"
|
||||
),
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=exit_code,
|
||||
timed_out=timed_out,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,127 @@
|
||||
"""CheckOperationStatusTool — query the status of a long-running operation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperationStatusResponse(ToolResponseBase):
|
||||
"""Response for check_operation_status tool."""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STATUS
|
||||
task_id: str
|
||||
operation_id: str
|
||||
status: str # "running", "completed", "failed"
|
||||
tool_name: str | None = None
|
||||
message: str = ""
|
||||
|
||||
|
||||
class CheckOperationStatusTool(BaseTool):
|
||||
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
|
||||
|
||||
The CoPilot uses this tool to report back to the user whether an
|
||||
operation that was started earlier has completed, failed, or is still
|
||||
running.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "check_operation_status"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Check the current status of a long-running operation such as "
|
||||
"create_agent or edit_agent. Accepts either an operation_id or "
|
||||
"task_id from a previous operation_started response. "
|
||||
"Returns the current status: running, completed, or failed."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The operation_id from an operation_started response."
|
||||
),
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The task_id from an operation_started response. "
|
||||
"Used as fallback if operation_id is not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
from backend.api.features.chat import stream_registry
|
||||
|
||||
operation_id = (kwargs.get("operation_id") or "").strip()
|
||||
task_id = (kwargs.get("task_id") or "").strip()
|
||||
|
||||
if not operation_id and not task_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an operation_id or task_id.",
|
||||
error="missing_parameter",
|
||||
)
|
||||
|
||||
task = None
|
||||
if operation_id:
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None and task_id:
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
# Task not in Redis — it may have already expired (TTL).
|
||||
# Check conversation history for the result instead.
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Operation not found — it may have already completed and "
|
||||
"expired from the status tracker. Check the conversation "
|
||||
"history for the result."
|
||||
),
|
||||
error="not_found",
|
||||
)
|
||||
|
||||
status_messages = {
|
||||
"running": (
|
||||
f"The {task.tool_name or 'operation'} is still running. "
|
||||
"Please wait for it to complete."
|
||||
),
|
||||
"completed": (
|
||||
f"The {task.tool_name or 'operation'} has completed successfully."
|
||||
),
|
||||
"failed": f"The {task.tool_name or 'operation'} has failed.",
|
||||
}
|
||||
|
||||
return OperationStatusResponse(
|
||||
task_id=task.task_id,
|
||||
operation_id=task.operation_id,
|
||||
status=task.status,
|
||||
tool_name=task.tool_name,
|
||||
message=status_messages.get(task.status, f"Status: {task.status}"),
|
||||
)
|
||||
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -3,9 +3,9 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import store_db as get_store_db
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -137,8 +137,6 @@ class CustomizeAgentTool(BaseTool):
|
||||
|
||||
creator_username, agent_slug = parts
|
||||
|
||||
store_db = get_store_db()
|
||||
|
||||
# Fetch the marketplace agent details
|
||||
try:
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
@@ -3,17 +3,17 @@ from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.copilot.tools.models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.api.features.chat.tools.models import (
|
||||
BlockInfoSummary,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
)
|
||||
from backend.data.db_accessors import search
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,7 +107,7 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await search().unified_hybrid_search(
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
@@ -146,6 +146,7 @@ class FindBlockTool(BaseTool):
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=[c.value for c in block.categories],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.find_block import (
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
FindBlockTool,
|
||||
)
|
||||
from backend.copilot.tools.models import BlockListResponse
|
||||
from backend.api.features.chat.tools.models import BlockListResponse
|
||||
from backend.blocks._base import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
@@ -84,17 +84,13 @@ class TestFindBlockFiltering:
|
||||
"standard-block-id": standard_block,
|
||||
}.get(block_id)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 2)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
@@ -132,17 +128,13 @@ class TestFindBlockFiltering:
|
||||
"normal-block-id": normal_block,
|
||||
}.get(block_id)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 2)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
@@ -4,9 +4,9 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.copilot.tools.models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocPageResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
@@ -41,6 +41,12 @@ class ResponseType(str, Enum):
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
# Web fetch
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Operation status check
|
||||
OPERATION_STATUS = "operation_status"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -335,6 +341,19 @@ class BlockInfoSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
input_schema: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Full JSON schema for block inputs",
|
||||
)
|
||||
output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Full JSON schema for block outputs",
|
||||
)
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of input fields for this block",
|
||||
)
|
||||
|
||||
|
||||
class BlockListResponse(ToolResponseBase):
|
||||
@@ -344,6 +363,10 @@ class BlockListResponse(ToolResponseBase):
|
||||
blocks: list[BlockInfoSummary]
|
||||
count: int
|
||||
query: str
|
||||
usage_hint: str = Field(
|
||||
default="To execute a block, call run_block with block_id set to the block's "
|
||||
"'id' field and input_data containing the fields listed in required_inputs."
|
||||
)
|
||||
|
||||
|
||||
class BlockDetails(BaseModel):
|
||||
@@ -430,3 +453,24 @@ class AsyncProcessingResponse(ToolResponseBase):
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
|
||||
class WebFetchResponse(ToolResponseBase):
|
||||
"""Response for web_fetch tool."""
|
||||
|
||||
type: ResponseType = ResponseType.WEB_FETCH
|
||||
url: str
|
||||
status_code: int
|
||||
content_type: str
|
||||
content: str
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class BashExecResponse(ToolResponseBase):
|
||||
"""Response for bash_exec tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BASH_EXEC
|
||||
stdout: str
|
||||
stderr: str
|
||||
exit_code: int
|
||||
timed_out: bool = False
|
||||
@@ -5,12 +5,16 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
|
||||
from backend.data.db_accessors import graph_db, library_db, user_db
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import (
|
||||
track_agent_run_success,
|
||||
track_agent_scheduled,
|
||||
)
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
@@ -196,7 +200,7 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
# Priority: library_agent_id if provided
|
||||
if has_library_id:
|
||||
library_agent = await library_db().get_library_agent(
|
||||
library_agent = await library_db.get_library_agent(
|
||||
params.library_agent_id, user_id
|
||||
)
|
||||
if not library_agent:
|
||||
@@ -205,7 +209,9 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
graph = await graph_db().get_graph(
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id,
|
||||
library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
@@ -516,7 +522,7 @@ class RunAgentTool(BaseTool):
|
||||
library_agent = await get_or_create_library_agent(graph, user_id)
|
||||
|
||||
# Get user timezone
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
|
||||
|
||||
# Create schedule
|
||||
@@ -7,16 +7,16 @@ from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.find_block import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
)
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
@@ -276,7 +276,7 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Get or create user's workspace for CoPilot file operations
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Generate synthetic IDs for CoPilot context
|
||||
# Each chat session is treated as its own agent with one continuous run
|
||||
@@ -4,14 +4,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.models import (
|
||||
from backend.api.features.chat.tools.models import (
|
||||
BlockDetailsResponse,
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
)
|
||||
from backend.copilot.tools.run_block import RunBlockTool
|
||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||
from backend.blocks._base import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
@@ -77,7 +77,7 @@ class TestRunBlockFiltering:
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=input_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
@@ -103,7 +103,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=smart_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
@@ -127,7 +127,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
@@ -0,0 +1,265 @@
|
||||
"""Sandbox execution utilities for code execution tools.
|
||||
|
||||
Provides filesystem + network isolated command execution using **bubblewrap**
|
||||
(``bwrap``): whitelist-only filesystem (only system dirs visible read-only),
|
||||
writable workspace only, clean environment, network blocked.
|
||||
|
||||
Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox`
|
||||
and refuse to run if bubblewrap is not available.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_TIMEOUT = 30
|
||||
_MAX_TIMEOUT = 120
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sandbox capability detection (cached at first call)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_BWRAP_AVAILABLE: bool | None = None
|
||||
|
||||
|
||||
def has_full_sandbox() -> bool:
|
||||
"""Return True if bubblewrap is available (filesystem + network isolation).
|
||||
|
||||
On non-Linux platforms (macOS), always returns False.
|
||||
"""
|
||||
global _BWRAP_AVAILABLE
|
||||
if _BWRAP_AVAILABLE is None:
|
||||
_BWRAP_AVAILABLE = (
|
||||
platform.system() == "Linux" and shutil.which("bwrap") is not None
|
||||
)
|
||||
return _BWRAP_AVAILABLE
|
||||
|
||||
|
||||
WORKSPACE_PREFIX = "/tmp/copilot-"
|
||||
|
||||
|
||||
def make_session_path(session_id: str) -> str:
|
||||
"""Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
|
||||
|
||||
Shared by both the SDK working-directory setup and the sandbox tools so
|
||||
they always resolve to the same directory for a given session.
|
||||
|
||||
Steps:
|
||||
1. Strip all characters except ``[A-Za-z0-9-]``.
|
||||
2. Construct ``/tmp/copilot-<safe_id>``.
|
||||
3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
|
||||
sanitizer) to prevent path traversal.
|
||||
|
||||
Raises:
|
||||
ValueError: If the resulting path escapes the prefix.
|
||||
"""
|
||||
import re
|
||||
|
||||
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
|
||||
if not safe_id:
|
||||
safe_id = "default"
|
||||
path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}")
|
||||
if not path.startswith(WORKSPACE_PREFIX):
|
||||
raise ValueError(f"Session path escaped prefix: {path}")
|
||||
return path
|
||||
|
||||
|
||||
def get_workspace_dir(session_id: str) -> str:
|
||||
"""Get or create the workspace directory for a session.
|
||||
|
||||
Uses :func:`make_session_path` — the same path the SDK uses — so that
|
||||
bash_exec shares the workspace with the SDK file tools.
|
||||
"""
|
||||
workspace = make_session_path(session_id)
|
||||
os.makedirs(workspace, exist_ok=True)
|
||||
return workspace
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bubblewrap command builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# System directories mounted read-only inside the sandbox.
|
||||
# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible.
|
||||
_SYSTEM_RO_BINDS = [
|
||||
"/usr", # binaries, libraries, Python interpreter
|
||||
"/etc", # system config: ld.so, locale, passwd, alternatives
|
||||
]
|
||||
|
||||
# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems.
|
||||
# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind
|
||||
# can't create a symlink target, so we detect and use --symlink instead.
|
||||
# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2.
|
||||
_COMPAT_PATHS = [
|
||||
("/bin", "usr/bin"), # -> /usr/bin on Debian 13
|
||||
("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13
|
||||
("/lib", "usr/lib"), # -> /usr/lib on Debian 13
|
||||
("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter
|
||||
]
|
||||
|
||||
# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse.
|
||||
# Applied via ulimit inside the sandbox before exec'ing the user command.
|
||||
_RESOURCE_LIMITS = (
|
||||
"ulimit -u 64" # max 64 processes (prevents fork bombs)
|
||||
" -v 524288" # 512 MB virtual memory
|
||||
" -f 51200" # 50 MB max file size (1024-byte blocks)
|
||||
" -n 256" # 256 open file descriptors
|
||||
" 2>/dev/null"
|
||||
)
|
||||
|
||||
|
||||
def _build_bwrap_command(
|
||||
command: list[str], cwd: str, env: dict[str, str]
|
||||
) -> list[str]:
|
||||
"""Build a bubblewrap command with strict filesystem + network isolation.
|
||||
|
||||
Security model:
|
||||
- **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``,
|
||||
``/bin``, ``/lib``) are mounted read-only. Application code (``/app``),
|
||||
home directories, ``/var``, ``/opt``, etc. are NOT accessible at all.
|
||||
- **Writable workspace only**: the per-session workspace is the sole
|
||||
writable path.
|
||||
- **Clean environment**: ``--clearenv`` wipes all inherited env vars.
|
||||
Only the explicitly-passed safe env vars are set inside the sandbox.
|
||||
- **Network isolation**: ``--unshare-net`` blocks all network access.
|
||||
- **Resource limits**: ulimit caps on processes (64), memory (512MB),
|
||||
file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
|
||||
- **New session**: prevents terminal control escape.
|
||||
- **Die with parent**: prevents orphaned sandbox processes.
|
||||
"""
|
||||
cmd = [
|
||||
"bwrap",
|
||||
# Create a new user namespace so bwrap can set up sandboxing
|
||||
# inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
|
||||
"--unshare-user",
|
||||
# Wipe all inherited environment variables (API keys, secrets, etc.)
|
||||
"--clearenv",
|
||||
]
|
||||
|
||||
# Set only the safe env vars inside the sandbox
|
||||
for key, value in env.items():
|
||||
cmd.extend(["--setenv", key, value])
|
||||
|
||||
# System directories: read-only
|
||||
for path in _SYSTEM_RO_BINDS:
|
||||
cmd.extend(["--ro-bind", path, path])
|
||||
|
||||
# Compat paths: use --symlink when host path is a symlink (Debian 13),
|
||||
# --ro-bind when it's a real directory (older distros).
|
||||
for path, symlink_target in _COMPAT_PATHS:
|
||||
if os.path.islink(path):
|
||||
cmd.extend(["--symlink", symlink_target, path])
|
||||
elif os.path.exists(path):
|
||||
cmd.extend(["--ro-bind", path, path])
|
||||
|
||||
# Wrap the user command with resource limits:
|
||||
# sh -c 'ulimit ...; exec "$@"' -- <original command>
|
||||
# `exec "$@"` replaces the shell so there's no extra process overhead,
|
||||
# and properly handles arguments with spaces.
|
||||
limited_command = [
|
||||
"sh",
|
||||
"-c",
|
||||
f'{_RESOURCE_LIMITS}; exec "$@"',
|
||||
"--",
|
||||
*command,
|
||||
]
|
||||
|
||||
cmd.extend(
|
||||
[
|
||||
# Fresh virtual filesystems
|
||||
"--dev",
|
||||
"/dev",
|
||||
"--proc",
|
||||
"/proc",
|
||||
"--tmpfs",
|
||||
"/tmp",
|
||||
# Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
|
||||
# (workspace lives under /tmp/copilot-<session>)
|
||||
"--bind",
|
||||
cwd,
|
||||
cwd,
|
||||
# Isolation
|
||||
"--unshare-net",
|
||||
"--die-with-parent",
|
||||
"--new-session",
|
||||
"--chdir",
|
||||
cwd,
|
||||
"--",
|
||||
*limited_command,
|
||||
]
|
||||
)
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_sandboxed(
|
||||
command: list[str],
|
||||
cwd: str,
|
||||
timeout: int = _DEFAULT_TIMEOUT,
|
||||
env: dict[str, str] | None = None,
|
||||
) -> tuple[str, str, int, bool]:
|
||||
"""Run a command inside a bubblewrap sandbox.
|
||||
|
||||
Callers **must** check :func:`has_full_sandbox` before calling this
|
||||
function. If bubblewrap is not available, this function raises
|
||||
:class:`RuntimeError` rather than running unsandboxed.
|
||||
|
||||
Returns:
|
||||
(stdout, stderr, exit_code, timed_out)
|
||||
"""
|
||||
if not has_full_sandbox():
|
||||
raise RuntimeError(
|
||||
"run_sandboxed() requires bubblewrap but bwrap is not available. "
|
||||
"Callers must check has_full_sandbox() before calling this function."
|
||||
)
|
||||
|
||||
timeout = min(max(timeout, 1), _MAX_TIMEOUT)
|
||||
|
||||
safe_env = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin",
|
||||
"HOME": cwd,
|
||||
"TMPDIR": cwd,
|
||||
"LANG": "en_US.UTF-8",
|
||||
"PYTHONDONTWRITEBYTECODE": "1",
|
||||
"PYTHONIOENCODING": "utf-8",
|
||||
}
|
||||
if env:
|
||||
safe_env.update(env)
|
||||
|
||||
full_command = _build_bwrap_command(command, cwd, safe_env)
|
||||
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*full_command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=safe_env,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout_bytes, stderr_bytes = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=timeout
|
||||
)
|
||||
stdout = stdout_bytes.decode("utf-8", errors="replace")
|
||||
stderr = stderr_bytes.decode("utf-8", errors="replace")
|
||||
return stdout, stderr, proc.returncode or 0, False
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.communicate()
|
||||
return "", f"Execution timed out after {timeout}s", -1, True
|
||||
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
return "", f"Sandbox error: {e}", -1, False
|
||||
@@ -5,16 +5,16 @@ from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.copilot.tools.models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.data.db_accessors import search
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -117,7 +117,7 @@ class SearchDocsTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Search using hybrid search for DOCUMENTATION content type only
|
||||
results, total = await search().unified_hybrid_search(
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
@@ -4,9 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.models import BlockDetailsResponse
|
||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.models import BlockDetailsResponse
|
||||
from backend.copilot.tools.run_block import RunBlockTool
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
@@ -37,14 +38,13 @@ async def fetch_graph_from_store_slug(
|
||||
Raises:
|
||||
DatabaseError: If there's a database error during lookup.
|
||||
"""
|
||||
sdb = store_db()
|
||||
try:
|
||||
store_agent = await sdb.get_store_agent_details(username, agent_name)
|
||||
store_agent = await store_db.get_store_agent_details(username, agent_name)
|
||||
except NotFoundError:
|
||||
return None, None
|
||||
|
||||
# Get the graph from store listing version
|
||||
graph = await sdb.get_available_graph(
|
||||
graph = await store_db.get_available_graph(
|
||||
store_agent.store_listing_version_id, hide_nodes=False
|
||||
)
|
||||
return graph, store_agent
|
||||
@@ -209,13 +209,13 @@ async def get_or_create_library_agent(
|
||||
Returns:
|
||||
LibraryAgent instance
|
||||
"""
|
||||
existing = await library_db().get_library_agent_by_graph_id(
|
||||
existing = await library_db.get_library_agent_by_graph_id(
|
||||
graph_id=graph.id, user_id=user_id
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
library_agents = await library_db().create_library_agent(
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
@@ -0,0 +1,151 @@
|
||||
"""Web fetch tool — safely retrieve public web page content."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import html2text
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
WebFetchResponse,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Limits
|
||||
_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap
|
||||
_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15)
|
||||
|
||||
# Content types we'll read as text
|
||||
_TEXT_CONTENT_TYPES = {
|
||||
"text/html",
|
||||
"text/plain",
|
||||
"text/xml",
|
||||
"text/csv",
|
||||
"text/markdown",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/xhtml+xml",
|
||||
"application/rss+xml",
|
||||
"application/atom+xml",
|
||||
}
|
||||
|
||||
|
||||
def _is_text_content(content_type: str) -> bool:
|
||||
base = content_type.split(";")[0].strip().lower()
|
||||
return base in _TEXT_CONTENT_TYPES or base.startswith("text/")
|
||||
|
||||
|
||||
def _html_to_text(html: str) -> str:
|
||||
h = html2text.HTML2Text()
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
h.body_width = 0
|
||||
return h.handle(html)
|
||||
|
||||
|
||||
class WebFetchTool(BaseTool):
|
||||
"""Safely fetch content from a public URL using SSRF-protected HTTP."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "web_fetch"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Fetch the content of a public web page by URL. "
|
||||
"Returns readable text extracted from HTML by default. "
|
||||
"Useful for reading documentation, articles, and API responses. "
|
||||
"Only supports HTTP/HTTPS GET requests to public URLs "
|
||||
"(private/internal network addresses are blocked)."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The public HTTP/HTTPS URL to fetch.",
|
||||
},
|
||||
"extract_text": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true (default), extract readable text from HTML. "
|
||||
"If false, return raw content."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
url: str = (kwargs.get("url") or "").strip()
|
||||
extract_text: bool = kwargs.get("extract_text", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not url:
|
||||
return ErrorResponse(
|
||||
message="Please provide a URL to fetch.",
|
||||
error="missing_url",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
client = Requests(raise_for_status=False, retry_max_attempts=1)
|
||||
response = await client.get(url, timeout=_REQUEST_TIMEOUT)
|
||||
except ValueError as e:
|
||||
# validate_url raises ValueError for SSRF / blocked IPs
|
||||
return ErrorResponse(
|
||||
message=f"URL blocked: {e}",
|
||||
error="url_blocked",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[web_fetch] Request failed for {url}: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Failed to fetch URL: {e}",
|
||||
error="fetch_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if not _is_text_content(content_type):
|
||||
return ErrorResponse(
|
||||
message=f"Non-text content type: {content_type.split(';')[0]}",
|
||||
error="unsupported_content_type",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
raw = response.content[:_MAX_CONTENT_BYTES]
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
|
||||
if extract_text and "html" in content_type.lower():
|
||||
text = _html_to_text(text)
|
||||
|
||||
return WebFetchResponse(
|
||||
message=f"Fetched {url}",
|
||||
url=response.url,
|
||||
status_code=response.status,
|
||||
content_type=content_type.split(";")[0].strip(),
|
||||
content=text,
|
||||
truncated=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -6,8 +6,8 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
@@ -88,7 +88,9 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the user's workspace. "
|
||||
"List files in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Read/Glob tools instead. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
@@ -146,7 +148,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
@@ -204,7 +206,9 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a file from the user's workspace. "
|
||||
"Read a file from the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Read tool instead. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
@@ -280,7 +284,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
@@ -378,7 +382,9 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write or create a file in the user's workspace. "
|
||||
"Write or create a file in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Write tool instead. "
|
||||
"Provide the content as a base64-encoded string. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
@@ -478,7 +484,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
@@ -523,7 +529,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a file from the user's workspace. "
|
||||
"Delete a file from the user's persistent workspace (cloud storage). "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
@@ -577,7 +583,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
@@ -40,11 +40,11 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.copilot.completion_consumer import (
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
@@ -38,9 +38,7 @@ def main(**kwargs):
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.copilot.executor.manager import CoPilotExecutor
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
|
||||
run_processes(
|
||||
@@ -50,7 +48,6 @@ def main(**kwargs):
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
CoPilotExecutor(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -28,54 +27,6 @@ async def server():
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
|
||||
async def graph_cleanup(server):
|
||||
created_graph_ids = []
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""CoPilot module - AI assistant for AutoGPT platform.
|
||||
|
||||
This module contains the core CoPilot functionality including:
|
||||
- AI generation service (LLM calls)
|
||||
- Tool execution
|
||||
- Session management
|
||||
- Stream registry for SSE reconnection
|
||||
"""
|
||||
@@ -1,5 +0,0 @@
|
||||
"""CoPilot Executor - Dedicated service for AI generation and tool execution.
|
||||
|
||||
This module contains the executor service that processes CoPilot tasks
|
||||
from RabbitMQ, following the graph executor pattern.
|
||||
"""
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Entry point for running the CoPilot Executor service.
|
||||
|
||||
Usage:
|
||||
python -m backend.copilot.executor
|
||||
"""
|
||||
|
||||
from backend.app import run_processes
|
||||
|
||||
from .manager import CoPilotExecutor
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the CoPilot Executor service."""
|
||||
run_processes(CoPilotExecutor())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,508 +0,0 @@
|
||||
"""CoPilot Executor Manager - main service for CoPilot task execution.
|
||||
|
||||
This module contains the CoPilotExecutor class that consumes chat tasks from
|
||||
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.exceptions import AMQPChannelError, AMQPConnectionError
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .processor import execute_copilot_task, init_worker
|
||||
from .utils import (
|
||||
COPILOT_CANCEL_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
create_copilot_queue_config,
|
||||
)
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
settings = Settings()
|
||||
|
||||
# Prometheus metrics
|
||||
active_tasks_gauge = Gauge(
|
||||
"copilot_executor_active_tasks",
|
||||
"Number of active CoPilot tasks",
|
||||
)
|
||||
pool_size_gauge = Gauge(
|
||||
"copilot_executor_pool_size",
|
||||
"Maximum number of CoPilot executor workers",
|
||||
)
|
||||
utilization_gauge = Gauge(
|
||||
"copilot_executor_utilization_ratio",
|
||||
"Ratio of active tasks to pool size",
|
||||
)
|
||||
|
||||
|
||||
class CoPilotExecutor(AppProcess):
|
||||
"""CoPilot Executor service for processing chat generation tasks.
|
||||
|
||||
This service consumes tasks from RabbitMQ, processes them using a thread pool,
|
||||
and publishes results to Redis Streams. It follows the graph executor pattern
|
||||
for reliable message handling and graceful shutdown.
|
||||
|
||||
Key features:
|
||||
- RabbitMQ-based task distribution with manual acknowledgment
|
||||
- Thread pool executor for concurrent task processing
|
||||
- Cluster lock for duplicate prevention across pods
|
||||
- Graceful shutdown with timeout for in-flight tasks
|
||||
- FANOUT exchange for cancellation broadcast
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_copilot_workers
|
||||
self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self.executor_id = str(uuid.uuid4())
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
|
||||
self._cancel_thread = None
|
||||
self._cancel_client = None
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
self._task_locks: dict[str, ClusterLock] = {}
|
||||
|
||||
# ============ Main Entry Points (AppProcess interface) ============ #
|
||||
|
||||
def run(self):
|
||||
"""Main service loop - consume from RabbitMQ."""
|
||||
logger.info(f"Pod assigned executor_id: {self.executor_id}")
|
||||
logger.info(f"Spawn max-{self.pool_size} workers...")
|
||||
|
||||
pool_size_gauge.set(self.pool_size)
|
||||
self._update_metrics()
|
||||
start_http_server(settings.config.copilot_executor_port)
|
||||
|
||||
self.cancel_thread.start()
|
||||
self.run_thread.start()
|
||||
|
||||
while True:
|
||||
time.sleep(1e5)
|
||||
|
||||
def cleanup(self):
|
||||
"""Graceful shutdown with active execution waiting."""
|
||||
pid = os.getpid()
|
||||
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
|
||||
|
||||
# Signal the consumer thread to stop
|
||||
try:
|
||||
self.stop_consuming.set()
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.connection.add_callback_threadsafe(
|
||||
lambda: run_channel.stop_consuming()
|
||||
)
|
||||
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
|
||||
except Exception as e:
|
||||
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
|
||||
|
||||
# Wait for active executions to complete
|
||||
if self.active_tasks:
|
||||
logger.info(
|
||||
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
|
||||
)
|
||||
|
||||
start_time = time.monotonic()
|
||||
last_refresh = start_time
|
||||
lock_refresh_interval = settings.config.cluster_lock_timeout / 10
|
||||
|
||||
while (
|
||||
self.active_tasks
|
||||
and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
|
||||
):
|
||||
self._cleanup_completed_tasks()
|
||||
if not self.active_tasks:
|
||||
break
|
||||
|
||||
# Refresh cluster locks periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= lock_refresh_interval:
|
||||
for lock in self._task_locks.values():
|
||||
try:
|
||||
lock.refresh()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[cleanup {pid}] Failed to refresh lock: {e}"
|
||||
)
|
||||
last_refresh = current_time
|
||||
|
||||
logger.info(
|
||||
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
|
||||
)
|
||||
time.sleep(10.0)
|
||||
|
||||
# Stop message consumers
|
||||
if self._run_thread:
|
||||
self._stop_message_consumers(
|
||||
self._run_thread, self.run_client, "[cleanup][run]"
|
||||
)
|
||||
if self._cancel_thread:
|
||||
self._stop_message_consumers(
|
||||
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
|
||||
)
|
||||
|
||||
# Shutdown executor
|
||||
if self._executor:
|
||||
logger.info(f"[cleanup {pid}] Shutting down executor...")
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Release any remaining locks
|
||||
for task_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
lock.release()
|
||||
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
|
||||
|
||||
# ============ RabbitMQ Consumer Methods ============ #
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_cancel(self):
|
||||
"""Consume cancellation messages from FANOUT exchange."""
|
||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
||||
logger.info("Stop reconnecting cancel consumer - service cleaned up")
|
||||
return
|
||||
|
||||
if not self.cancel_client.is_ready:
|
||||
self.cancel_client.disconnect()
|
||||
self.cancel_client.connect()
|
||||
|
||||
# Check again after connect - shutdown may have been requested
|
||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||
self.cancel_client.disconnect()
|
||||
return
|
||||
|
||||
cancel_channel = self.cancel_client.get_channel()
|
||||
cancel_channel.basic_consume(
|
||||
queue=COPILOT_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
logger.info("Starting cancel message consumer...")
|
||||
cancel_channel.start_consuming()
|
||||
if not self.stop_consuming.is_set() or self.active_tasks:
|
||||
raise RuntimeError("Cancel message consumer stopped unexpectedly")
|
||||
logger.info("Cancel message consumer stopped gracefully")
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_run(self):
|
||||
"""Consume run messages from DIRECT exchange."""
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Stop reconnecting run consumer - service cleaned up")
|
||||
return
|
||||
|
||||
if not self.run_client.is_ready:
|
||||
self.run_client.disconnect()
|
||||
self.run_client.connect()
|
||||
|
||||
# Check again after connect - shutdown may have been requested
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||
self.run_client.disconnect()
|
||||
return
|
||||
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.basic_qos(prefetch_count=self.pool_size)
|
||||
|
||||
run_channel.basic_consume(
|
||||
queue=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
on_message_callback=self._handle_run_message,
|
||||
auto_ack=False,
|
||||
consumer_tag="copilot_execution_consumer",
|
||||
)
|
||||
logger.info("Starting to consume run messages...")
|
||||
run_channel.start_consuming()
|
||||
if not self.stop_consuming.is_set():
|
||||
raise RuntimeError("Run message consumer stopped unexpectedly")
|
||||
logger.info("Run message consumer stopped gracefully")
|
||||
|
||||
# ============ Message Handlers ============ #
|
||||
|
||||
@error_logged(swallow=True)
|
||||
def _handle_cancel_message(
|
||||
self,
|
||||
_channel: BlockingChannel,
|
||||
_method: Basic.Deliver,
|
||||
_properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""Handle cancel message from FANOUT exchange."""
|
||||
request = CancelCoPilotEvent.model_validate_json(body)
|
||||
task_id = request.task_id
|
||||
if not task_id:
|
||||
logger.warning("Cancel message missing 'task_id'")
|
||||
return
|
||||
if task_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {task_id} but not active")
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_tasks[task_id]
|
||||
logger.info(f"Received cancel for {task_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(f"Cancel already set for {task_id}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
_channel: BlockingChannel,
|
||||
method: Basic.Deliver,
|
||||
_properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""Handle run message from DIRECT exchange."""
|
||||
delivery_tag = method.delivery_tag
|
||||
# Capture the channel used at message delivery time to ensure we ack
|
||||
# on the correct channel. Delivery tags are channel-scoped and become
|
||||
# invalid if the channel is recreated after reconnection.
|
||||
delivery_channel = _channel
|
||||
|
||||
def ack_message(reject: bool, requeue: bool):
|
||||
"""Acknowledge or reject the message.
|
||||
|
||||
Uses the channel from the original message delivery. If the channel
|
||||
is no longer open (e.g., after reconnection), logs a warning and
|
||||
skips the ack - RabbitMQ will redeliver the message automatically.
|
||||
"""
|
||||
try:
|
||||
if not delivery_channel.is_open:
|
||||
logger.warning(
|
||||
f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
|
||||
"Message will be redelivered by RabbitMQ."
|
||||
)
|
||||
return
|
||||
|
||||
if reject:
|
||||
delivery_channel.connection.add_callback_threadsafe(
|
||||
lambda: delivery_channel.basic_nack(
|
||||
delivery_tag, requeue=requeue
|
||||
)
|
||||
)
|
||||
else:
|
||||
delivery_channel.connection.add_callback_threadsafe(
|
||||
lambda: delivery_channel.basic_ack(delivery_tag)
|
||||
)
|
||||
except (AMQPChannelError, AMQPConnectionError) as e:
|
||||
# Channel/connection errors indicate stale delivery tag - don't retry
|
||||
logger.warning(
|
||||
f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
|
||||
f"error: {e}. Message will be redelivered by RabbitMQ."
|
||||
)
|
||||
except Exception as e:
|
||||
# Other errors might be transient, but log and skip to avoid blocking
|
||||
logger.error(
|
||||
f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
|
||||
)
|
||||
|
||||
# Check if we're shutting down
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Rejecting new task during shutdown")
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Check if we can accept more tasks
|
||||
self._cleanup_completed_tasks()
|
||||
if len(self.active_tasks) >= self.pool_size:
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
try:
|
||||
entry = CoPilotExecutionEntry.model_validate_json(body)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not parse run message: {e}, body={body}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
task_id = entry.task_id
|
||||
|
||||
# Check for local duplicate - task is already running on this executor
|
||||
if task_id in self.active_tasks:
|
||||
logger.warning(
|
||||
f"Task {task_id} already running locally, rejecting duplicate"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
# Try to acquire cluster-wide lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"copilot:task:{task_id}:lock",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
if current_owner is not None:
|
||||
logger.warning(f"Task {task_id} already running on pod {current_owner}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not acquire lock for {task_id} - Redis unavailable"
|
||||
)
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
self._task_locks[task_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
future = self.executor.submit(
|
||||
execute_copilot_task, entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_tasks[task_id] = (future, cancel_event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup execution for {task_id}: {e}")
|
||||
cluster_lock.release()
|
||||
if task_id in self._task_locks:
|
||||
del self._task_locks[task_id]
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
self._update_metrics()
|
||||
|
||||
def on_run_done(f: Future):
|
||||
logger.info(f"Run completed for {task_id}")
|
||||
try:
|
||||
if exec_error := f.exception():
|
||||
logger.error(f"Execution for {task_id} failed: {exec_error}")
|
||||
# Don't requeue failed tasks - they've been marked as failed
|
||||
# in the stream registry. Requeuing would cause infinite retries
|
||||
# for deterministic failures.
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
ack_message(reject=False, requeue=False)
|
||||
except BaseException as e:
|
||||
logger.exception(f"Error in run completion callback: {e}")
|
||||
finally:
|
||||
# Release the cluster lock
|
||||
if task_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {task_id}")
|
||||
self._task_locks[task_id].release()
|
||||
del self._task_locks[task_id]
|
||||
self._cleanup_completed_tasks()
|
||||
|
||||
future.add_done_callback(on_run_done)
|
||||
|
||||
# ============ Helper Methods ============ #
|
||||
|
||||
def _cleanup_completed_tasks(self) -> list[str]:
|
||||
"""Remove completed futures from active_tasks and update metrics."""
|
||||
completed_tasks = []
|
||||
for task_id, (future, _) in self.active_tasks.items():
|
||||
if future.done():
|
||||
completed_tasks.append(task_id)
|
||||
|
||||
for task_id in completed_tasks:
|
||||
logger.info(f"Cleaned up completed task {task_id}")
|
||||
self.active_tasks.pop(task_id, None)
|
||||
|
||||
self._update_metrics()
|
||||
return completed_tasks
|
||||
|
||||
def _update_metrics(self):
|
||||
"""Update Prometheus metrics."""
|
||||
active_count = len(self.active_tasks)
|
||||
active_tasks_gauge.set(active_count)
|
||||
if self.stop_consuming.is_set():
|
||||
utilization_gauge.set(1.0)
|
||||
else:
|
||||
utilization_gauge.set(
|
||||
active_count / self.pool_size if self.pool_size > 0 else 0
|
||||
)
|
||||
|
||||
def _stop_message_consumers(
|
||||
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
|
||||
):
|
||||
"""Stop a message consumer thread."""
|
||||
try:
|
||||
channel = client.get_channel()
|
||||
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
|
||||
|
||||
thread.join(timeout=300)
|
||||
if thread.is_alive():
|
||||
logger.error(
|
||||
f"{prefix} Thread did not finish in time, forcing disconnect"
|
||||
)
|
||||
|
||||
client.disconnect()
|
||||
logger.info(f"{prefix} Client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} Error disconnecting client: {e}")
|
||||
|
||||
# ============ Lazy-initialized Properties ============ #
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
self._cancel_thread = threading.Thread(
|
||||
target=lambda: self._consume_cancel(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._cancel_thread
|
||||
|
||||
@property
|
||||
def run_thread(self) -> threading.Thread:
|
||||
if self._run_thread is None:
|
||||
self._run_thread = threading.Thread(
|
||||
target=lambda: self._consume_run(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._run_thread
|
||||
|
||||
@property
|
||||
def stop_consuming(self) -> threading.Event:
|
||||
if self._stop_consuming is None:
|
||||
self._stop_consuming = threading.Event()
|
||||
return self._stop_consuming
|
||||
|
||||
@property
|
||||
def executor(self) -> ThreadPoolExecutor:
|
||||
if self._executor is None:
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=init_worker,
|
||||
)
|
||||
return self._executor
|
||||
|
||||
@property
|
||||
def cancel_client(self) -> SyncRabbitMQ:
|
||||
if self._cancel_client is None:
|
||||
self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||
return self._cancel_client
|
||||
|
||||
@property
|
||||
def run_client(self) -> SyncRabbitMQ:
|
||||
if self._run_client is None:
|
||||
self._run_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||
return self._run_client
|
||||
@@ -1,237 +0,0 @@
|
||||
"""CoPilot execution processor - per-worker execution logic.
|
||||
|
||||
This module contains the processor class that handles CoPilot task execution
|
||||
in a thread-local context, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
|
||||
|
||||
# ============ Module Entry Points ============ #
|
||||
|
||||
# Thread-local storage for processor instances
|
||||
_tls = threading.local()
|
||||
|
||||
|
||||
def execute_copilot_task(
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot task using the thread-local processor.
|
||||
|
||||
This function is the entry point called by the thread pool executor.
|
||||
|
||||
Args:
|
||||
entry: The task payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for this execution
|
||||
"""
|
||||
processor: CoPilotProcessor = _tls.processor
|
||||
return processor.execute(entry, cancel, cluster_lock)
|
||||
|
||||
|
||||
def init_worker():
|
||||
"""Initialize the processor for the current worker thread.
|
||||
|
||||
This function is called by the thread pool executor when a new worker
|
||||
thread is created. It ensures each worker has its own processor instance.
|
||||
"""
|
||||
_tls.processor = CoPilotProcessor()
|
||||
_tls.processor.on_executor_start()
|
||||
|
||||
|
||||
# ============ Processor Class ============ #
|
||||
|
||||
|
||||
class CoPilotProcessor:
|
||||
"""Per-worker execution logic for CoPilot tasks.
|
||||
|
||||
This class is instantiated once per worker thread and handles the execution
|
||||
of CoPilot chat generation tasks. It maintains an async event loop for
|
||||
running the async service code.
|
||||
|
||||
The execution flow:
|
||||
1. CoPilot task is picked from RabbitMQ queue
|
||||
2. Manager submits task to thread pool
|
||||
3. Processor executes the task in its event loop
|
||||
4. Results are published to Redis Streams
|
||||
"""
|
||||
|
||||
@func_retry
|
||||
def on_executor_start(self):
|
||||
"""Initialize the processor when the worker thread starts.
|
||||
|
||||
This method is called once per worker thread to set up the async event
|
||||
loop and initialize any required resources.
|
||||
|
||||
Database is accessed only through DatabaseManager, so we don't need to connect
|
||||
to Prisma directly.
|
||||
"""
|
||||
configure_logging()
|
||||
set_service_name("CoPilotExecutor")
|
||||
self.tid = threading.get_ident()
|
||||
self.execution_loop = asyncio.new_event_loop()
|
||||
self.execution_thread = threading.Thread(
|
||||
target=self.execution_loop.run_forever, daemon=True
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
def execute(
|
||||
self,
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot task.
|
||||
|
||||
This is the main entry point for task execution. It runs the async
|
||||
execution logic in the worker's event loop and handles errors.
|
||||
|
||||
Args:
|
||||
entry: The task payload containing session and message info
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock to prevent duplicate execution
|
||||
"""
|
||||
log = CoPilotLogMetadata(
|
||||
logging.getLogger(__name__),
|
||||
task_id=entry.task_id,
|
||||
session_id=entry.session_id,
|
||||
user_id=entry.user_id,
|
||||
)
|
||||
log.info("Starting execution")
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
# Run the async execution in our event loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._execute_async(entry, cancel, cluster_lock, log),
|
||||
self.execution_loop,
|
||||
)
|
||||
|
||||
# Wait for completion, checking cancel periodically
|
||||
while not future.done():
|
||||
try:
|
||||
future.result(timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
if cancel.is_set():
|
||||
log.info("Cancellation requested")
|
||||
future.cancel()
|
||||
break
|
||||
# Refresh cluster lock to maintain ownership
|
||||
cluster_lock.refresh()
|
||||
|
||||
if not future.cancelled():
|
||||
# Get result to propagate any exceptions
|
||||
future.result()
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.info(f"Execution completed in {elapsed:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
|
||||
# Note: _execute_async already marks the task as failed before re-raising,
|
||||
# so we don't call _mark_task_failed here to avoid duplicate error events.
|
||||
raise
|
||||
|
||||
async def _execute_async(
|
||||
self,
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
log: CoPilotLogMetadata,
|
||||
):
|
||||
"""Async execution logic for CoPilot task.
|
||||
|
||||
This method calls the existing stream_chat_completion service function
|
||||
and publishes results to the stream registry.
|
||||
|
||||
Args:
|
||||
entry: The task payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for refresh
|
||||
log: Structured logger for this task
|
||||
"""
|
||||
last_refresh = time.monotonic()
|
||||
refresh_interval = 30.0 # Refresh lock every 30 seconds
|
||||
|
||||
try:
|
||||
# Stream chat completion and publish chunks to Redis
|
||||
async for chunk in copilot_service.stream_chat_completion(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
_task_id=entry.task_id,
|
||||
):
|
||||
# Check for cancellation
|
||||
if cancel.is_set():
|
||||
log.info("Cancelled during streaming")
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamError(errorText="Operation cancelled")
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamFinishStep()
|
||||
)
|
||||
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(
|
||||
entry.task_id, status="failed"
|
||||
)
|
||||
return
|
||||
|
||||
# Refresh cluster lock periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Publish chunk to stream registry
|
||||
await stream_registry.publish_chunk(entry.task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="completed")
|
||||
log.info("Task completed successfully")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("Task cancelled")
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="failed")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Task failed: {e}")
|
||||
await self._mark_task_failed(entry.task_id, str(e))
|
||||
raise
|
||||
|
||||
async def _mark_task_failed(self, task_id: str, error_message: str):
|
||||
"""Mark a task as failed and publish error to stream registry."""
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
task_id, StreamError(errorText=error_message)
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark task {task_id} as failed: {e}")
|
||||
@@ -1,207 +0,0 @@
|
||||
"""RabbitMQ queue configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues following the graph executor pattern:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ Logging Helper ============ #
|
||||
|
||||
|
||||
class CoPilotLogMetadata(TruncatedLogger):
|
||||
"""Structured logging helper for CoPilot executor.
|
||||
|
||||
In cloud environments (structured logging enabled), uses a simple prefix
|
||||
and passes metadata via json_fields. In local environments, uses a detailed
|
||||
prefix with all metadata key-value pairs for easier debugging.
|
||||
|
||||
Args:
|
||||
logger: The underlying logger instance
|
||||
max_length: Maximum log message length before truncation
|
||||
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
|
||||
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: logging.Logger,
|
||||
max_length: int = 1000,
|
||||
**kwargs: str | None,
|
||||
):
|
||||
# Filter out None values
|
||||
metadata = {k: v for k, v in kwargs.items() if v is not None}
|
||||
metadata["component"] = "CoPilotExecutor"
|
||||
|
||||
if is_structured_logging_enabled():
|
||||
prefix = "[CoPilotExecutor]"
|
||||
else:
|
||||
# Build prefix from metadata key-value pairs
|
||||
meta_parts = "|".join(
|
||||
f"{k}:{v}" for k, v in metadata.items() if k != "component"
|
||||
)
|
||||
prefix = (
|
||||
f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
logger,
|
||||
max_length=max_length,
|
||||
prefix=prefix,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
# ============ Exchange and Queue Configuration ============ #
|
||||
|
||||
COPILOT_EXECUTION_EXCHANGE = Exchange(
|
||||
name="copilot_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
|
||||
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
|
||||
|
||||
COPILOT_CANCEL_EXCHANGE = Exchange(
|
||||
name="copilot_cancel",
|
||||
type=ExchangeType.FANOUT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
|
||||
|
||||
# CoPilot operations can include extended thinking and agent generation
|
||||
# which may take 30+ minutes to complete
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
|
||||
|
||||
# Graceful shutdown timeout - allow in-flight operations to complete
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def create_copilot_queue_config() -> RabbitMQConfig:
|
||||
"""Create RabbitMQ configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
|
||||
Returns:
|
||||
RabbitMQConfig with exchanges and queues defined
|
||||
"""
|
||||
run_queue = Queue(
|
||||
name=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
arguments={
|
||||
# Extended consumer timeout for long-running LLM operations
|
||||
# Default 30-minute timeout is insufficient for extended thinking
|
||||
# and agent generation which can take 30+ minutes
|
||||
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
* 1000,
|
||||
},
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
name=COPILOT_CANCEL_QUEUE_NAME,
|
||||
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||
routing_key="", # not used for FANOUT
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
return RabbitMQConfig(
|
||||
vhost="/",
|
||||
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
)
|
||||
|
||||
|
||||
# ============ Message Models ============ #
|
||||
|
||||
|
||||
class CoPilotExecutionEntry(BaseModel):
|
||||
"""Task payload for CoPilot AI generation.
|
||||
|
||||
This model represents a chat generation task to be processed by the executor.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
"""Unique identifier for this task (used for stream registry)"""
|
||||
|
||||
session_id: str
|
||||
"""Chat session ID"""
|
||||
|
||||
user_id: str | None
|
||||
"""User ID (may be None for anonymous users)"""
|
||||
|
||||
operation_id: str
|
||||
"""Operation ID for webhook callbacks and completion tracking"""
|
||||
|
||||
message: str
|
||||
"""User's message to process"""
|
||||
|
||||
is_user_message: bool = True
|
||||
"""Whether the message is from the user (vs system/assistant)"""
|
||||
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
|
||||
task_id: str
|
||||
"""Task ID to cancel"""
|
||||
|
||||
|
||||
# ============ Queue Publishing Helpers ============ #
|
||||
|
||||
|
||||
async def enqueue_copilot_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
operation_id: str,
|
||||
message: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for this task (used for stream registry)
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous users)
|
||||
operation_id: Operation ID for webhook callbacks and completion tracking
|
||||
message: User's message to process
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
entry = CoPilotExecutionEntry(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
message=entry.model_dump_json(),
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
)
|
||||
@@ -1,82 +0,0 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
||||
user_id=session.user_id,
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
@@ -1,118 +0,0 @@
|
||||
from backend.data import db
|
||||
|
||||
|
||||
def chat_db():
|
||||
if db.is_connected():
|
||||
from backend.copilot import db as _chat_db
|
||||
|
||||
chat_db = _chat_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
chat_db = get_database_manager_async_client()
|
||||
|
||||
return chat_db
|
||||
|
||||
|
||||
def graph_db():
|
||||
if db.is_connected():
|
||||
from backend.data import graph as _graph_db
|
||||
|
||||
graph_db = _graph_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
graph_db = get_database_manager_async_client()
|
||||
|
||||
return graph_db
|
||||
|
||||
|
||||
def library_db():
|
||||
if db.is_connected():
|
||||
from backend.api.features.library import db as _library_db
|
||||
|
||||
library_db = _library_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
library_db = get_database_manager_async_client()
|
||||
|
||||
return library_db
|
||||
|
||||
|
||||
def store_db():
|
||||
if db.is_connected():
|
||||
from backend.api.features.store import db as _store_db
|
||||
|
||||
store_db = _store_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
store_db = get_database_manager_async_client()
|
||||
|
||||
return store_db
|
||||
|
||||
|
||||
def search():
|
||||
if db.is_connected():
|
||||
from backend.api.features.store import hybrid_search as _search
|
||||
|
||||
search = _search
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
search = get_database_manager_async_client()
|
||||
|
||||
return search
|
||||
|
||||
|
||||
def execution_db():
|
||||
if db.is_connected():
|
||||
from backend.data import execution as _execution_db
|
||||
|
||||
execution_db = _execution_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
execution_db = get_database_manager_async_client()
|
||||
|
||||
return execution_db
|
||||
|
||||
|
||||
def user_db():
|
||||
if db.is_connected():
|
||||
from backend.data import user as _user_db
|
||||
|
||||
user_db = _user_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
user_db = get_database_manager_async_client()
|
||||
|
||||
return user_db
|
||||
|
||||
|
||||
def understanding_db():
|
||||
if db.is_connected():
|
||||
from backend.data import understanding as _understanding_db
|
||||
|
||||
understanding_db = _understanding_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
understanding_db = get_database_manager_async_client()
|
||||
|
||||
return understanding_db
|
||||
|
||||
|
||||
def workspace_db():
|
||||
if db.is_connected():
|
||||
from backend.data import workspace as _workspace_db
|
||||
|
||||
workspace_db = _workspace_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
workspace_db = get_database_manager_async_client()
|
||||
|
||||
return workspace_db
|
||||
@@ -1,5 +1,5 @@
|
||||
from backend.app import run_processes
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from .manager import ExecutionManager
|
||||
from .scheduler import Scheduler
|
||||
|
||||
__all__ = [
|
||||
"DatabaseManager",
|
||||
"DatabaseManagerClient",
|
||||
"DatabaseManagerAsyncClient",
|
||||
"ExecutionManager",
|
||||
"Scheduler",
|
||||
]
|
||||
|
||||
@@ -22,7 +22,7 @@ from backend.util.settings import Settings
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.db_manager import DatabaseManagerAsyncClient
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.db_manager import DatabaseManagerAsyncClient
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Redis-based distributed locking for cluster coordination."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -20,7 +19,6 @@ class ClusterLock:
|
||||
self.owner_id = owner_id
|
||||
self.timeout = timeout
|
||||
self._last_refresh = 0.0
|
||||
self._refresh_lock = threading.Lock()
|
||||
|
||||
def try_acquire(self) -> str | None:
|
||||
"""Try to acquire the lock.
|
||||
@@ -33,8 +31,7 @@ class ClusterLock:
|
||||
try:
|
||||
success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
|
||||
if success:
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = time.time()
|
||||
self._last_refresh = time.time()
|
||||
return self.owner_id # Successfully acquired
|
||||
|
||||
# Failed to acquire, get current owner
|
||||
@@ -60,27 +57,23 @@ class ClusterLock:
|
||||
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
|
||||
During rate limiting, still verifies lock existence but skips TTL extension.
|
||||
Setting _last_refresh to 0 bypasses rate limiting for testing.
|
||||
|
||||
Thread-safe: uses _refresh_lock to protect _last_refresh access.
|
||||
"""
|
||||
# Calculate refresh interval: max(timeout // 10, 1)
|
||||
refresh_interval = max(self.timeout // 10, 1)
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we're within the rate limit period (thread-safe read)
|
||||
# Check if we're within the rate limit period
|
||||
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
|
||||
with self._refresh_lock:
|
||||
last_refresh = self._last_refresh
|
||||
is_rate_limited = (
|
||||
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
|
||||
self._last_refresh > 0
|
||||
and (current_time - self._last_refresh) < refresh_interval
|
||||
)
|
||||
|
||||
try:
|
||||
# Always verify lock existence, even during rate limiting
|
||||
current_value = self.redis.get(self.key)
|
||||
if not current_value:
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
stored_owner = (
|
||||
@@ -89,8 +82,7 @@ class ClusterLock:
|
||||
else str(current_value)
|
||||
)
|
||||
if stored_owner != self.owner_id:
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
# If rate limited, return True but don't update TTL or timestamp
|
||||
@@ -99,30 +91,25 @@ class ClusterLock:
|
||||
|
||||
# Perform actual refresh
|
||||
if self.redis.expire(self.key, self.timeout):
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = current_time
|
||||
self._last_refresh = current_time
|
||||
return True
|
||||
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
"""Release the lock."""
|
||||
with self._refresh_lock:
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
self.redis.delete(self.key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
self._last_refresh = 0.0
|
||||
|
||||
@@ -4,26 +4,14 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
create_graph_in_library,
|
||||
create_library_agent,
|
||||
get_library_agent,
|
||||
get_library_agent_by_graph_id,
|
||||
list_library_agents,
|
||||
update_graph_in_library,
|
||||
)
|
||||
from backend.api.features.store.db import (
|
||||
get_agent,
|
||||
get_available_graph,
|
||||
get_store_agent_details,
|
||||
get_store_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
cleanup_orphaned_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.copilot import db as chat_db
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -60,7 +48,6 @@ from backend.data.graph import (
|
||||
get_graph_metadata,
|
||||
get_graph_settings,
|
||||
get_node,
|
||||
get_store_listed_graphs,
|
||||
validate_graph_execution_permissions,
|
||||
)
|
||||
from backend.data.human_review import (
|
||||
@@ -80,10 +67,6 @@ from backend.data.notifications import (
|
||||
remove_notifications_from_batch,
|
||||
)
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.data.understanding import (
|
||||
get_business_understanding,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
get_user_by_id,
|
||||
@@ -93,7 +76,6 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -125,13 +107,6 @@ async def _get_credits(user_id: str) -> int:
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
"""Database connection pooling service.
|
||||
|
||||
This service connects to the Prisma engine and exposes database
|
||||
operations via RPC endpoints. It acts as a centralized connection pool
|
||||
for all services that need database access.
|
||||
"""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, app: "FastAPI"):
|
||||
async with super().lifespan(app):
|
||||
@@ -167,15 +142,11 @@ class DatabaseManager(AppService):
|
||||
def _(
|
||||
f: Callable[P, R], name: str | None = None
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
"""
|
||||
Exposes a function as an RPC endpoint, and adds a virtual `self` param
|
||||
to the function's type so it can be bound as a method.
|
||||
"""
|
||||
if name is not None:
|
||||
f.__name__ = name
|
||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||
|
||||
# ============ Graph Executions ============ #
|
||||
# Executions
|
||||
get_child_graph_executions = _(get_child_graph_executions)
|
||||
get_graph_executions = _(get_graph_executions)
|
||||
get_graph_executions_count = _(get_graph_executions_count)
|
||||
@@ -199,37 +170,36 @@ class DatabaseManager(AppService):
|
||||
get_frequently_executed_graphs = _(get_frequently_executed_graphs)
|
||||
get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)
|
||||
|
||||
# ============ Graphs ============ #
|
||||
# Graphs
|
||||
get_node = _(get_node)
|
||||
get_graph = _(get_graph)
|
||||
get_connected_output_nodes = _(get_connected_output_nodes)
|
||||
get_graph_metadata = _(get_graph_metadata)
|
||||
get_graph_settings = _(get_graph_settings)
|
||||
get_store_listed_graphs = _(get_store_listed_graphs)
|
||||
|
||||
# ============ Credits ============ #
|
||||
# Credits
|
||||
spend_credits = _(_spend_credits, name="spend_credits")
|
||||
get_credits = _(_get_credits, name="get_credits")
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
|
||||
# ============ User Comms ============ #
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
get_user_email_by_id = _(get_user_email_by_id)
|
||||
get_user_email_verification = _(get_user_email_verification)
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
|
||||
# ============ Human In The Loop ============ #
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||
update_review_processed_status = _(update_review_processed_status)
|
||||
|
||||
# ============ Notifications ============ #
|
||||
# Notifications - async
|
||||
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
create_or_add_to_user_notification_batch
|
||||
@@ -242,56 +212,29 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# ============ Library ============ #
|
||||
# Library
|
||||
list_library_agents = _(list_library_agents)
|
||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||
create_graph_in_library = _(create_graph_in_library)
|
||||
create_library_agent = _(create_library_agent)
|
||||
get_library_agent = _(get_library_agent)
|
||||
get_library_agent_by_graph_id = _(get_library_agent_by_graph_id)
|
||||
update_graph_in_library = _(update_graph_in_library)
|
||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
# Onboarding
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
# ============ OAuth ============ #
|
||||
# OAuth
|
||||
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
||||
|
||||
# ============ Store ============ #
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
get_agent = _(get_agent)
|
||||
get_available_graph = _(get_available_graph)
|
||||
|
||||
# ============ Search ============ #
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(get_embedding_stats)
|
||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
|
||||
unified_hybrid_search = _(unified_hybrid_search)
|
||||
|
||||
# ============ Summary Data ============ #
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
# ============ Workspace ============ #
|
||||
get_or_create_workspace = _(get_or_create_workspace)
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
upsert_business_understanding = _(upsert_business_understanding)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = _(chat_db.get_chat_session)
|
||||
create_chat_session = _(chat_db.create_chat_session)
|
||||
update_chat_session = _(chat_db.update_chat_session)
|
||||
add_chat_message = _(chat_db.add_chat_message)
|
||||
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
|
||||
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
||||
get_user_session_count = _(chat_db.get_user_session_count)
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -353,50 +296,43 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# ============ Graph Executions ============ #
|
||||
create_graph_execution = d.create_graph_execution
|
||||
get_child_graph_executions = d.get_child_graph_executions
|
||||
get_connected_output_nodes = d.get_connected_output_nodes
|
||||
get_latest_node_execution = d.get_latest_node_execution
|
||||
get_graph_execution = d.get_graph_execution
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_graph_executions = d.get_graph_executions
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
# ============ Graphs ============ #
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_settings = d.get_graph_settings
|
||||
get_graph_execution = d.get_graph_execution
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_store_listed_graphs = d.get_store_listed_graphs
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
get_user_by_id = d.get_user_by_id
|
||||
get_user_integrations = d.get_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
update_user_integrations = d.update_user_integrations
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
# ============ Human In The Loop ============ #
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
update_review_processed_status = d.update_review_processed_status
|
||||
|
||||
# ============ User Comms ============ #
|
||||
# User Comms
|
||||
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
||||
get_user_email_by_id = d.get_user_email_by_id
|
||||
get_user_email_verification = d.get_user_email_verification
|
||||
get_user_notification_preference = d.get_user_notification_preference
|
||||
|
||||
# ============ Notifications ============ #
|
||||
# Notifications
|
||||
clear_all_user_notification_batches = d.clear_all_user_notification_batches
|
||||
create_or_add_to_user_notification_batch = (
|
||||
d.create_or_add_to_user_notification_batch
|
||||
@@ -409,49 +345,20 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# ============ Library ============ #
|
||||
# Library
|
||||
list_library_agents = d.list_library_agents
|
||||
add_store_agent_to_library = d.add_store_agent_to_library
|
||||
create_graph_in_library = d.create_graph_in_library
|
||||
create_library_agent = d.create_library_agent
|
||||
get_library_agent = d.get_library_agent
|
||||
get_library_agent_by_graph_id = d.get_library_agent_by_graph_id
|
||||
update_graph_in_library = d.update_graph_in_library
|
||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
# Onboarding
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
# ============ OAuth ============ #
|
||||
# OAuth
|
||||
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
||||
|
||||
# ============ Store ============ #
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
get_agent = d.get_agent
|
||||
get_available_graph = d.get_available_graph
|
||||
|
||||
# ============ Search ============ #
|
||||
unified_hybrid_search = d.unified_hybrid_search
|
||||
|
||||
# ============ Summary Data ============ #
|
||||
# Summary data
|
||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||
|
||||
# ============ Workspace ============ #
|
||||
get_or_create_workspace = d.get_or_create_workspace
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = d.get_chat_session
|
||||
create_chat_session = d.create_chat_session
|
||||
update_chat_session = d.update_chat_session
|
||||
add_chat_message = d.add_chat_message
|
||||
add_chat_messages_batch = d.add_chat_messages_batch
|
||||
get_user_chat_sessions = d.get_user_chat_sessions
|
||||
get_user_session_count = d.get_user_session_count
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_chat_session_message_count = d.get_chat_session_message_count
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
@@ -92,10 +92,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.db_manager import (
|
||||
DatabaseManagerAsyncClient,
|
||||
DatabaseManagerClient,
|
||||
)
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,15 +13,12 @@ if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
from supabase import AClient, Client
|
||||
|
||||
from backend.data.db_manager import (
|
||||
DatabaseManagerAsyncClient,
|
||||
DatabaseManagerClient,
|
||||
)
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
RedisExecutionEventBus,
|
||||
)
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
@@ -30,7 +27,7 @@ if TYPE_CHECKING:
|
||||
@thread_cached
|
||||
def get_database_manager_client() -> "DatabaseManagerClient":
|
||||
"""Get a thread-cached DatabaseManagerClient with request retry enabled."""
|
||||
from backend.data.db_manager import DatabaseManagerClient
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerClient, request_retry=True)
|
||||
@@ -41,7 +38,7 @@ def get_database_manager_async_client(
|
||||
should_retry: bool = True,
|
||||
) -> "DatabaseManagerAsyncClient":
|
||||
"""Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
|
||||
from backend.data.db_manager import DatabaseManagerAsyncClient
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
|
||||
@@ -109,20 +106,6 @@ async def get_async_execution_queue() -> "AsyncRabbitMQ":
|
||||
return client
|
||||
|
||||
|
||||
# ============ CoPilot Queue Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_copilot_queue() -> "AsyncRabbitMQ":
|
||||
"""Get a thread-cached AsyncRabbitMQ CoPilot queue client."""
|
||||
from backend.copilot.executor.utils import create_copilot_queue_config
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ
|
||||
|
||||
client = AsyncRabbitMQ(create_copilot_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
# ============ Integration Credentials Store ============ #
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ class Flag(str, Enum):
|
||||
AGENT_ACTIVITY = "agent-activity"
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
CHAT = "chat"
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
|
||||
@@ -211,23 +211,16 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for execution manager daemon to run on",
|
||||
)
|
||||
|
||||
num_copilot_workers: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of concurrent CoPilot executor workers",
|
||||
)
|
||||
|
||||
copilot_executor_port: int = Field(
|
||||
default=8008,
|
||||
description="The port for CoPilot executor daemon to run on",
|
||||
)
|
||||
|
||||
execution_scheduler_port: int = Field(
|
||||
default=8003,
|
||||
description="The port for execution scheduler daemon to run on",
|
||||
)
|
||||
|
||||
agent_server_port: int = Field(
|
||||
default=8004,
|
||||
description="The port for agent server daemon to run on",
|
||||
)
|
||||
|
||||
database_api_port: int = Field(
|
||||
default=8005,
|
||||
description="The port for database server API to run on",
|
||||
|
||||
@@ -11,7 +11,6 @@ from backend.api.rest_api import AgentServer
|
||||
from backend.blocks._base import Block, BlockSchema
|
||||
from backend.data import db
|
||||
from backend.data.block import initialize_blocks
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
@@ -20,7 +19,7 @@ from backend.data.execution import (
|
||||
)
|
||||
from backend.data.model import _BaseCredentials
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
94
autogpt_platform/backend/poetry.lock
generated
94
autogpt_platform/backend/poetry.lock
generated
@@ -897,6 +897,29 @@ files = [
|
||||
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "claude-agent-sdk"
|
||||
version = "0.1.35"
|
||||
description = "Python SDK for Claude Code"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
|
||||
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.0.0"
|
||||
mcp = ">=0.1.0"
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "cleo"
|
||||
version = "2.1.0"
|
||||
@@ -2593,6 +2616,18 @@ http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.3"
|
||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
|
||||
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.4.1"
|
||||
@@ -3310,6 +3345,39 @@ files = [
|
||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp"
|
||||
version = "1.26.0"
|
||||
description = "Model Context Protocol SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
|
||||
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.5"
|
||||
httpx = ">=0.27.1"
|
||||
httpx-sse = ">=0.4"
|
||||
jsonschema = ">=4.20.0"
|
||||
pydantic = ">=2.11.0,<3.0.0"
|
||||
pydantic-settings = ">=2.5.2"
|
||||
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
||||
python-multipart = ">=0.0.9"
|
||||
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
|
||||
sse-starlette = ">=1.6.1"
|
||||
starlette = ">=0.27"
|
||||
typing-extensions = ">=4.9.0"
|
||||
typing-inspection = ">=0.4.1"
|
||||
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
|
||||
|
||||
[package.extras]
|
||||
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
|
||||
rich = ["rich (>=13.9.4)"]
|
||||
ws = ["websockets (>=15.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
@@ -5994,7 +6062,7 @@ description = "Python for Window Extensions"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "platform_system == \"Windows\""
|
||||
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
|
||||
files = [
|
||||
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
||||
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
||||
@@ -6974,6 +7042,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
||||
pymysql = ["pymysql"]
|
||||
sqlcipher = ["sqlcipher3_binary"]
|
||||
|
||||
[[package]]
|
||||
name = "sse-starlette"
|
||||
version = "3.2.0"
|
||||
description = "SSE plugin for Starlette"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
|
||||
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.7.0"
|
||||
starlette = ">=0.49.1"
|
||||
|
||||
[package.extras]
|
||||
daphne = ["daphne (>=4.2.0)"]
|
||||
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
|
||||
granian = ["granian (>=2.3.1)"]
|
||||
uvicorn = ["uvicorn (>=0.34.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "stagehand"
|
||||
version = "0.5.9"
|
||||
@@ -8440,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f"
|
||||
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
|
||||
|
||||
@@ -16,6 +16,7 @@ anthropic = "^0.79.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
claude-agent-sdk = "^0.1.0"
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
@@ -116,7 +117,6 @@ ws = "backend.ws:main"
|
||||
scheduler = "backend.scheduler:main"
|
||||
notification = "backend.notification:main"
|
||||
executor = "backend.exec:main"
|
||||
copilot-executor = "backend.copilot.executor.__main__:main"
|
||||
cli = "backend.cli:main"
|
||||
format = "linter:format"
|
||||
lint = "linter:lint"
|
||||
|
||||
@@ -9,8 +9,10 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.agent_generator import core
|
||||
from backend.copilot.tools.agent_generator.core import AgentGeneratorNotConfiguredError
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
from backend.api.features.chat.tools.agent_generator.core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
)
|
||||
|
||||
|
||||
class TestServiceNotConfigured:
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.agent_generator import core
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
|
||||
|
||||
class TestGetLibraryAgentsForGeneration:
|
||||
@@ -31,20 +31,18 @@ class TestGetLibraryAgentsForGeneration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [mock_agent]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"library_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="send email",
|
||||
)
|
||||
|
||||
mock_db.list_library_agents.assert_called_once_with(
|
||||
mock_list.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term="send email",
|
||||
page=1,
|
||||
@@ -82,13 +80,11 @@ class TestGetLibraryAgentsForGeneration:
|
||||
),
|
||||
]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"library_db",
|
||||
return_value=mock_db,
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
@@ -105,20 +101,18 @@ class TestGetLibraryAgentsForGeneration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"library_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
max_results=5,
|
||||
)
|
||||
|
||||
mock_db.list_library_agents.assert_called_once_with(
|
||||
mock_list.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term=None,
|
||||
page=1,
|
||||
@@ -150,24 +144,24 @@ class TestSearchMarketplaceAgentsForGeneration:
|
||||
mock_graph.input_schema = {"type": "object"}
|
||||
mock_graph.output_schema = {"type": "object"}
|
||||
|
||||
mock_store_db = MagicMock()
|
||||
mock_store_db.get_store_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
mock_graph_db = MagicMock()
|
||||
mock_graph_db.get_store_listed_graphs = AsyncMock(
|
||||
return_value={"graph-123": mock_graph}
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(core, "store_db", return_value=mock_store_db),
|
||||
patch.object(core, "graph_db", return_value=mock_graph_db),
|
||||
patch(
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_search,
|
||||
patch(
|
||||
"backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"graph-123": mock_graph},
|
||||
),
|
||||
):
|
||||
result = await core.search_marketplace_agents_for_generation(
|
||||
search_query="automation",
|
||||
max_results=10,
|
||||
)
|
||||
|
||||
mock_store_db.get_store_agents.assert_called_once_with(
|
||||
mock_search.assert_called_once_with(
|
||||
search_query="automation",
|
||||
page=1,
|
||||
page_size=10,
|
||||
@@ -713,7 +707,7 @@ class TestExtractUuidsFromText:
|
||||
|
||||
|
||||
class TestGetLibraryAgentById:
|
||||
"""Test get_library_agent_by_id function (alias: get_library_agent_by_graph_id)."""
|
||||
"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_agent_when_found_by_graph_id(self):
|
||||
@@ -726,10 +720,12 @@ class TestGetLibraryAgentById:
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is not None
|
||||
@@ -747,11 +743,20 @@ class TestGetLibraryAgentById:
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
|
||||
mock_db.get_library_agent = AsyncMock(return_value=mock_agent)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # Not found by graph_id
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent, # Found by library ID
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "library-id-123")
|
||||
|
||||
assert result is not None
|
||||
@@ -761,13 +766,20 @@ class TestGetLibraryAgentById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_found_by_either_method(self):
|
||||
"""Test that None is returned when agent not found by either method."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
|
||||
mock_db.get_library_agent = AsyncMock(
|
||||
side_effect=core.NotFoundError("Not found")
|
||||
)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=core.NotFoundError("Not found"),
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "nonexistent")
|
||||
|
||||
assert result is None
|
||||
@@ -775,20 +787,27 @@ class TestGetLibraryAgentById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_exception(self):
|
||||
"""Test that None is returned when exception occurs in both lookups."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
mock_db.get_library_agent = AsyncMock(side_effect=Exception("Database error"))
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alias_works(self):
|
||||
"""Test that get_library_agent_by_graph_id is an alias."""
|
||||
"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
|
||||
assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
|
||||
|
||||
|
||||
@@ -809,11 +828,20 @@ class TestGetAllRelevantAgentsWithUuids:
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
|
||||
|
||||
@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.agent_generator import service
|
||||
from backend.api.features.chat.tools.agent_generator import service
|
||||
|
||||
|
||||
class TestServiceConfiguration:
|
||||
|
||||
0
autogpt_platform/backend/test/chat/__init__.py
Normal file
0
autogpt_platform/backend/test/chat/__init__.py
Normal file
133
autogpt_platform/backend/test/chat/test_security_hooks.py
Normal file
133
autogpt_platform/backend/test/chat/test_security_hooks.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
|
||||
|
||||
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
|
||||
They validate that the security hooks correctly block unauthorized paths,
|
||||
tool access, and dangerous input patterns.
|
||||
|
||||
Note: Bash command validation was removed — the SDK built-in Bash tool is not in
|
||||
allowed_tools, and the bash_exec MCP tool has kernel-level network isolation
|
||||
(unshare --net) making command-level parsing unnecessary.
|
||||
"""
|
||||
|
||||
from backend.api.features.chat.sdk.security_hooks import (
|
||||
_validate_tool_access,
|
||||
_validate_workspace_path,
|
||||
)
|
||||
|
||||
SDK_CWD = "/tmp/copilot-test-session"
|
||||
|
||||
|
||||
def _is_denied(result: dict) -> bool:
|
||||
hook = result.get("hookSpecificOutput", {})
|
||||
return hook.get("permissionDecision") == "deny"
|
||||
|
||||
|
||||
def _reason(result: dict) -> str:
|
||||
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Workspace path validation (Read, Write, Edit, etc.)
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestWorkspacePathValidation:
|
||||
def test_path_in_workspace(self):
|
||||
result = _validate_workspace_path(
|
||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
def test_path_outside_workspace(self):
|
||||
result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
def test_tool_results_allowed(self):
|
||||
result = _validate_workspace_path(
|
||||
"Read",
|
||||
{"file_path": "~/.claude/projects/abc/tool-results/out.txt"},
|
||||
SDK_CWD,
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
def test_claude_settings_blocked(self):
|
||||
result = _validate_workspace_path(
|
||||
"Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
def test_claude_projects_without_tool_results(self):
|
||||
result = _validate_workspace_path(
|
||||
"Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
def test_no_path_allowed(self):
|
||||
"""Glob/Grep without path defaults to cwd — should be allowed."""
|
||||
result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD)
|
||||
assert not _is_denied(result)
|
||||
|
||||
def test_path_traversal_with_dotdot(self):
|
||||
result = _validate_workspace_path(
|
||||
"Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Tool access validation
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestToolAccessValidation:
|
||||
def test_blocked_tools(self):
|
||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
||||
result = _validate_tool_access(tool, {})
|
||||
assert _is_denied(result), f"Tool '{tool}' should be blocked"
|
||||
|
||||
def test_bash_builtin_blocked(self):
|
||||
"""SDK built-in Bash (capital) is blocked as defence-in-depth."""
|
||||
result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
assert "Bash" in _reason(result)
|
||||
|
||||
def test_workspace_tools_delegate(self):
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
def test_dangerous_pattern_blocked(self):
|
||||
result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"})
|
||||
assert _is_denied(result)
|
||||
|
||||
def test_safe_unknown_tool_allowed(self):
|
||||
result = _validate_tool_access("SomeSafeTool", {"data": "hello world"})
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Deny message quality (ntindle feedback)
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestDenyMessageClarity:
|
||||
"""Deny messages must include [SECURITY] and 'cannot be bypassed'
|
||||
so the model knows the restriction is enforced, not a suggestion."""
|
||||
|
||||
def test_blocked_tool_message(self):
|
||||
reason = _reason(_validate_tool_access("bash", {}))
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
|
||||
def test_bash_builtin_blocked_message(self):
|
||||
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
|
||||
def test_workspace_path_message(self):
|
||||
reason = _reason(
|
||||
_validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
|
||||
)
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
255
autogpt_platform/backend/test/chat/test_transcript.py
Normal file
255
autogpt_platform/backend/test/chat/test_transcript.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from backend.api.features.chat.sdk.transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
read_transcript_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
METADATA_LINE = {"type": "queue-operation", "subtype": "create"}
|
||||
FILE_HISTORY = {"type": "file-history-snapshot", "files": []}
|
||||
USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
||||
ASST_MSG = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {"role": "assistant", "content": "hello"},
|
||||
}
|
||||
PROGRESS_ENTRY = {
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"parentUuid": "u1",
|
||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
||||
}
|
||||
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
# --- read_transcript_file ---
|
||||
|
||||
|
||||
class TestReadTranscriptFile:
|
||||
def test_returns_content_for_valid_file(self, tmp_path):
|
||||
path = tmp_path / "session.jsonl"
|
||||
path.write_text(VALID_TRANSCRIPT)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
assert "user" in result
|
||||
|
||||
def test_returns_none_for_missing_file(self):
|
||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
||||
|
||||
def test_returns_none_for_empty_path(self):
|
||||
assert read_transcript_file("") is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.jsonl"
|
||||
path.write_text("")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
path = tmp_path / "meta.jsonl"
|
||||
path.write_text(content)
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
||||
path = tmp_path / "bad.jsonl"
|
||||
path.write_text("not json\n{}\n{}\n")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_no_size_limit(self, tmp_path):
|
||||
"""Large files are accepted — bucket storage has no size limit."""
|
||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
||||
path = tmp_path / "big.jsonl"
|
||||
path.write_text(content)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
|
||||
|
||||
# --- write_transcript_to_tempfile ---
|
||||
|
||||
|
||||
class TestWriteTranscriptToTempfile:
|
||||
"""Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check."""
|
||||
|
||||
def test_writes_file_and_returns_path(self):
|
||||
cwd = "/tmp/copilot-test-write"
|
||||
try:
|
||||
result = write_transcript_to_tempfile(
|
||||
VALID_TRANSCRIPT, "sess-1234-abcd", cwd
|
||||
)
|
||||
assert result is not None
|
||||
assert os.path.isfile(result)
|
||||
assert result.endswith(".jsonl")
|
||||
with open(result) as f:
|
||||
assert f.read() == VALID_TRANSCRIPT
|
||||
finally:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(cwd, ignore_errors=True)
|
||||
|
||||
def test_creates_parent_directory(self):
|
||||
cwd = "/tmp/copilot-test-mkdir"
|
||||
try:
|
||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
||||
assert result is not None
|
||||
assert os.path.isdir(cwd)
|
||||
finally:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(cwd, ignore_errors=True)
|
||||
|
||||
def test_uses_session_id_prefix(self):
|
||||
cwd = "/tmp/copilot-test-prefix"
|
||||
try:
|
||||
result = write_transcript_to_tempfile(
|
||||
VALID_TRANSCRIPT, "abcdef12-rest", cwd
|
||||
)
|
||||
assert result is not None
|
||||
assert "abcdef12" in os.path.basename(result)
|
||||
finally:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(cwd, ignore_errors=True)
|
||||
|
||||
def test_rejects_cwd_outside_sandbox(self, tmp_path):
|
||||
cwd = str(tmp_path / "not-copilot")
|
||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- validate_transcript ---
|
||||
|
||||
|
||||
class TestValidateTranscript:
|
||||
def test_valid_transcript(self):
|
||||
assert validate_transcript(VALID_TRANSCRIPT) is True
|
||||
|
||||
def test_none_content(self):
|
||||
assert validate_transcript(None) is False
|
||||
|
||||
def test_empty_content(self):
|
||||
assert validate_transcript("") is False
|
||||
|
||||
def test_metadata_only(self):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_user_only_no_assistant(self):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG)
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_assistant_only_no_user(self):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_invalid_json_returns_false(self):
|
||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
||||
|
||||
|
||||
# --- strip_progress_entries ---
|
||||
|
||||
|
||||
class TestStripProgressEntries:
|
||||
def test_strips_all_strippable_types(self):
|
||||
"""All STRIPPABLE_TYPES are removed from the output."""
|
||||
entries = [
|
||||
USER_MSG,
|
||||
{"type": "progress", "uuid": "p1", "parentUuid": "u1"},
|
||||
{"type": "file-history-snapshot", "files": []},
|
||||
{"type": "queue-operation", "subtype": "create"},
|
||||
{"type": "summary", "text": "..."},
|
||||
{"type": "pr-link", "url": "..."},
|
||||
ASST_MSG,
|
||||
]
|
||||
result = strip_progress_entries(_make_jsonl(*entries))
|
||||
result_types = {json.loads(line)["type"] for line in result.strip().split("\n")}
|
||||
assert result_types == {"user", "assistant"}
|
||||
for stype in STRIPPABLE_TYPES:
|
||||
assert stype not in result_types
|
||||
|
||||
def test_reparents_children_of_stripped_entries(self):
|
||||
"""An assistant message whose parent is a progress entry gets reparented."""
|
||||
progress = {
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"parentUuid": "u1",
|
||||
"data": {"type": "bash_progress"},
|
||||
}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p1", # Points to progress
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, progress, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
|
||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
||||
# Should be reparented to u1 (the user message)
|
||||
assert asst_entry["parentUuid"] == "u1"
|
||||
|
||||
def test_reparents_through_chain(self):
|
||||
"""Reparenting walks through multiple stripped entries."""
|
||||
p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
||||
p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"}
|
||||
p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p3", # 3 levels deep
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, p1, p2, p3, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
|
||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
||||
assert asst_entry["parentUuid"] == "u1"
|
||||
|
||||
def test_preserves_non_strippable_entries(self):
|
||||
"""User, assistant, and system entries are preserved."""
|
||||
system = {"type": "system", "uuid": "s1", "message": "prompt"}
|
||||
content = _make_jsonl(system, USER_MSG, ASST_MSG)
|
||||
result = strip_progress_entries(content)
|
||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
||||
assert result_types == ["system", "user", "assistant"]
|
||||
|
||||
def test_empty_input(self):
|
||||
result = strip_progress_entries("")
|
||||
# Should return just a newline (empty content stripped)
|
||||
assert result.strip() == ""
|
||||
|
||||
def test_no_strippable_entries(self):
|
||||
"""When there's nothing to strip, output matches input structure."""
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
result = strip_progress_entries(content)
|
||||
result_lines = result.strip().split("\n")
|
||||
assert len(result_lines) == 2
|
||||
|
||||
def test_handles_entries_without_uuid(self):
|
||||
"""Entries without uuid field are handled gracefully."""
|
||||
no_uuid = {"type": "queue-operation", "subtype": "create"}
|
||||
content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
|
||||
result = strip_progress_entries(content)
|
||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
||||
# queue-operation is strippable
|
||||
assert "queue-operation" not in result_types
|
||||
assert "user" in result_types
|
||||
assert "assistant" in result_types
|
||||
@@ -158,41 +158,6 @@ services:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
copilot_executor:
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: server
|
||||
command: ["python", "-m", "backend.copilot.executor"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
target: autogpt_platform/backend/
|
||||
action: rebuild
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
rabbitmq:
|
||||
condition: service_healthy
|
||||
db:
|
||||
condition: service_healthy
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
database_manager:
|
||||
condition: service_started
|
||||
<<: *backend-env-files
|
||||
environment:
|
||||
<<: *backend-env
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
- app-network
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
websocket_server:
|
||||
build:
|
||||
context: ../
|
||||
|
||||
@@ -53,12 +53,6 @@ services:
|
||||
file: ./docker-compose.platform.yml
|
||||
service: executor
|
||||
|
||||
copilot_executor:
|
||||
<<: *agpt-services
|
||||
extends:
|
||||
file: ./docker-compose.platform.yml
|
||||
service: copilot_executor
|
||||
|
||||
websocket_server:
|
||||
<<: *agpt-services
|
||||
extends:
|
||||
@@ -180,6 +174,5 @@ services:
|
||||
- deps
|
||||
- rest_server
|
||||
- executor
|
||||
- copilot_executor
|
||||
- websocket_server
|
||||
- database_manager
|
||||
|
||||
@@ -20,6 +20,7 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
|
||||
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
|
||||
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
|
||||
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
|
||||
import { GenericTool } from "../../tools/GenericTool/GenericTool";
|
||||
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -255,6 +256,16 @@ export const ChatMessagesContainer = ({
|
||||
/>
|
||||
);
|
||||
default:
|
||||
// Render a generic tool indicator for SDK built-in
|
||||
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
|
||||
if (part.type.startsWith("tool-")) {
|
||||
return (
|
||||
<GenericTool
|
||||
key={`${message.id}-${i}`}
|
||||
part={part as ToolUIPart}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
})}
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"use client";
|
||||
|
||||
import { ToolUIPart } from "ai";
|
||||
import { GearIcon } from "@phosphor-icons/react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
}
|
||||
|
||||
function extractToolName(part: ToolUIPart): string {
|
||||
// ToolUIPart.type is "tool-{name}", extract the name portion.
|
||||
return part.type.replace(/^tool-/, "");
|
||||
}
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
// "search_docs" → "Search docs", "Read" → "Read"
|
||||
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
function getAnimationText(part: ToolUIPart): string {
|
||||
const label = formatToolName(extractToolName(part));
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available":
|
||||
return `Running ${label}…`;
|
||||
case "output-available":
|
||||
return `${label} completed`;
|
||||
case "output-error":
|
||||
return `${label} failed`;
|
||||
default:
|
||||
return `Running ${label}…`;
|
||||
}
|
||||
}
|
||||
|
||||
export function GenericTool({ part }: Props) {
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<GearIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className={
|
||||
isError
|
||||
? "text-red-500"
|
||||
: isStreaming
|
||||
? "animate-spin text-neutral-500"
|
||||
: "text-neutral-400"
|
||||
}
|
||||
/>
|
||||
<MorphingTextAnimation
|
||||
text={getAnimationText(part)}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -7066,13 +7066,57 @@
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"description": { "type": "string", "title": "Description" }
|
||||
"description": { "type": "string", "title": "Description" },
|
||||
"categories": {
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Categories"
|
||||
},
|
||||
"input_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Input Schema",
|
||||
"description": "Full JSON schema for block inputs"
|
||||
},
|
||||
"output_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Output Schema",
|
||||
"description": "Full JSON schema for block outputs"
|
||||
},
|
||||
"required_inputs": {
|
||||
"items": { "$ref": "#/components/schemas/BlockInputFieldInfo" },
|
||||
"type": "array",
|
||||
"title": "Required Inputs",
|
||||
"description": "List of input fields for this block"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "name", "description"],
|
||||
"required": ["id", "name", "description", "categories"],
|
||||
"title": "BlockInfoSummary",
|
||||
"description": "Summary of a block for search results."
|
||||
},
|
||||
"BlockInputFieldInfo": {
|
||||
"properties": {
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"type": { "type": "string", "title": "Type" },
|
||||
"description": {
|
||||
"type": "string",
|
||||
"title": "Description",
|
||||
"default": ""
|
||||
},
|
||||
"required": {
|
||||
"type": "boolean",
|
||||
"title": "Required",
|
||||
"default": false
|
||||
},
|
||||
"default": { "anyOf": [{}, { "type": "null" }], "title": "Default" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["name", "type"],
|
||||
"title": "BlockInputFieldInfo",
|
||||
"description": "Information about a block input field."
|
||||
},
|
||||
"BlockListResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
@@ -7090,7 +7134,12 @@
|
||||
"title": "Blocks"
|
||||
},
|
||||
"count": { "type": "integer", "title": "Count" },
|
||||
"query": { "type": "string", "title": "Query" }
|
||||
"query": { "type": "string", "title": "Query" },
|
||||
"usage_hint": {
|
||||
"type": "string",
|
||||
"title": "Usage Hint",
|
||||
"default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message", "blocks", "count", "query"],
|
||||
@@ -10495,7 +10544,10 @@
|
||||
"operation_started",
|
||||
"operation_pending",
|
||||
"operation_in_progress",
|
||||
"input_validation_error"
|
||||
"input_validation_error",
|
||||
"web_fetch",
|
||||
"bash_exec",
|
||||
"operation_status"
|
||||
],
|
||||
"title": "ResponseType",
|
||||
"description": "Types of tool responses."
|
||||
|
||||
Reference in New Issue
Block a user