Compare commits

..

4 Commits

Author SHA1 Message Date
Otto-AGPT
cdeefb8621 fix(copilot): Use correct OpenRouter reasoning API format
Addresses review comments from CodeRabbit and Sentry:

- Change reasoning format from {"enabled": True} (invalid) to
  {"max_tokens": config.thinking_budget_tokens} per OpenRouter docs
- Add missing thinking_budget_tokens config field (default: 10000)
- Extract duplicate code into _apply_thinking_config() helper function
- Update description from 'adaptive' to 'extended' thinking for clarity

References:
- OpenRouter reasoning docs: https://openrouter.ai/docs/reasoning-tokens
2026-02-11 13:54:57 +00:00
Swifty
ba6d585170 update settings 2026-02-10 16:08:21 +01:00
Swifty
90eac56525 Merge branch 'dev' into fix/enable-extended-thinking 2026-02-10 15:26:40 +01:00
Otto
75f8772f8a feat(copilot): Enable extended thinking for Claude models
Adds configuration to enable Anthropic's extended thinking feature via
OpenRouter. This keeps the model's chain-of-thought reasoning internal
rather than outputting it to users.

Configuration:
- thinking_enabled: bool (default: True)
- thinking_budget_tokens: int (default: 10000)

The thinking config is only applied to Anthropic models (detected via
model name containing 'anthropic').

Fixes the issue where the CoPilot prompt expects thinking mode but it
wasn't enabled on the API side, causing internal reasoning to leak
into user-facing responses.
2026-02-10 13:58:57 +00:00
270 changed files with 2307 additions and 7637 deletions

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
ref: ${{ github.event.workflow_run.head_branch }} ref: ${{ github.event.workflow_run.head_branch }}
fetch-depth: 0 fetch-depth: 0

View File

@@ -30,7 +30,7 @@ jobs:
actions: read # Required for CI access actions: read # Required for CI access
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1

View File

@@ -40,7 +40,7 @@ jobs:
actions: read # Required for CI access actions: read # Required for CI access
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1

View File

@@ -58,7 +58,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL

View File

@@ -27,7 +27,7 @@ jobs:
# If you do not check out your code, Copilot will do this for you. # If you do not check out your code, Copilot will do this for you.
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: true submodules: true

View File

@@ -23,7 +23,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1

View File

@@ -23,7 +23,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0

View File

@@ -28,7 +28,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1

View File

@@ -25,7 +25,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
ref: ${{ github.event.inputs.git_ref || github.ref_name }} ref: ${{ github.event.inputs.git_ref || github.ref_name }}
@@ -52,7 +52,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Trigger deploy workflow - name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DEPLOY_TOKEN }} token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -17,7 +17,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
ref: ${{ github.ref_name || 'master' }} ref: ${{ github.ref_name || 'master' }}
@@ -45,7 +45,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Trigger deploy workflow - name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DEPLOY_TOKEN }} token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -68,7 +68,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: true submodules: true

View File

@@ -82,7 +82,7 @@ jobs:
- name: Dispatch Deploy Event - name: Dispatch Deploy Event
if: steps.check_status.outputs.should_deploy == 'true' if: steps.check_status.outputs.should_deploy == 'true'
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DISPATCH_TOKEN }} token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -110,7 +110,7 @@ jobs:
- name: Dispatch Undeploy Event (from comment) - name: Dispatch Undeploy Event (from comment)
if: steps.check_status.outputs.should_undeploy == 'true' if: steps.check_status.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DISPATCH_TOKEN }} token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -168,7 +168,7 @@ jobs:
github.event_name == 'pull_request' && github.event_name == 'pull_request' &&
github.event.action == 'closed' && github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true' steps.check_pr_close.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DISPATCH_TOKEN }} token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -31,7 +31,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
- name: Check for component changes - name: Check for component changes
uses: dorny/paths-filter@v3 uses: dorny/paths-filter@v3
@@ -71,7 +71,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v6 uses: actions/setup-node@v6
@@ -107,7 +107,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
@@ -148,7 +148,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
submodules: recursive submodules: recursive
@@ -277,7 +277,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
submodules: recursive submodules: recursive

View File

@@ -29,7 +29,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v6 uses: actions/setup-node@v6
@@ -63,7 +63,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
submodules: recursive submodules: recursive

View File

@@ -11,7 +11,7 @@ jobs:
steps: steps:
# - name: Wait some time for all actions to start # - name: Wait some time for all actions to start
# run: sleep 30 # run: sleep 30
- uses: actions/checkout@v6 - uses: actions/checkout@v4
# with: # with:
# fetch-depth: 0 # fetch-depth: 0
- name: Set up Python - name: Set up Python

View File

@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache import backend.api.features.store.cache as store_cache
import backend.api.features.store.model as store_model import backend.api.features.store.model as store_model
import backend.blocks import backend.data.block
from backend.api.external.middleware import require_permission from backend.api.external.middleware import require_permission
from backend.data import execution as execution_db from backend.data import execution as execution_db
from backend.data import graph as graph_db from backend.data import graph as graph_db
@@ -67,7 +67,7 @@ async def get_user_info(
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))], dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
) )
async def get_graph_blocks() -> Sequence[dict[Any, Any]]: async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.blocks.get_blocks().values()] blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled] return [b.to_dict() for b in blocks if not b.disabled]
@@ -83,7 +83,7 @@ async def execute_graph_block(
require_permission(APIKeyPermission.EXECUTE_BLOCK) require_permission(APIKeyPermission.EXECUTE_BLOCK)
), ),
) -> CompletedBlockOutput: ) -> CompletedBlockOutput:
obj = backend.blocks.get_block(block_id) obj = backend.data.block.get_block(block_id)
if not obj: if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.") raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled: if obj.disabled:

View File

@@ -10,15 +10,10 @@ import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model import backend.api.features.store.model as store_model
import backend.data.block
from backend.blocks import load_all_blocks from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
BlockCategory,
BlockInfo,
BlockSchema,
BlockType,
)
from backend.blocks.llm import LlmModel from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.cache import cached from backend.util.cache import cached
@@ -27,7 +22,7 @@ from backend.util.models import Pagination
from .model import ( from .model import (
BlockCategoryResponse, BlockCategoryResponse,
BlockResponse, BlockResponse,
BlockTypeFilter, BlockType,
CountResponse, CountResponse,
FilterType, FilterType,
Provider, Provider,
@@ -93,7 +88,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
def get_blocks( def get_blocks(
*, *,
category: str | None = None, category: str | None = None,
type: BlockTypeFilter | None = None, type: BlockType | None = None,
provider: ProviderName | None = None, provider: ProviderName | None = None,
page: int = 1, page: int = 1,
page_size: int = 50, page_size: int = 50,
@@ -674,9 +669,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
for block_type in load_all_blocks().values(): for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type() block: AnyBlockSchema = block_type()
if block.disabled or block.block_type in ( if block.disabled or block.block_type in (
BlockType.INPUT, backend.data.block.BlockType.INPUT,
BlockType.OUTPUT, backend.data.block.BlockType.OUTPUT,
BlockType.AGENT, backend.data.block.BlockType.AGENT,
): ):
continue continue
# Find the execution count for this block # Find the execution count for this block

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
import backend.api.features.library.model as library_model import backend.api.features.library.model as library_model
import backend.api.features.store.model as store_model import backend.api.features.store.model as store_model
from backend.blocks._base import BlockInfo from backend.data.block import BlockInfo
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.models import Pagination from backend.util.models import Pagination
@@ -15,7 +15,7 @@ FilterType = Literal[
"my_agents", "my_agents",
] ]
BlockTypeFilter = Literal["all", "input", "action", "output"] BlockType = Literal["all", "input", "action", "output"]
class SearchEntry(BaseModel): class SearchEntry(BaseModel):

View File

@@ -88,7 +88,7 @@ async def get_block_categories(
) )
async def get_blocks( async def get_blocks(
category: Annotated[str | None, fastapi.Query()] = None, category: Annotated[str | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None, type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
provider: Annotated[ProviderName | None, fastapi.Query()] = None, provider: Annotated[ProviderName | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1, page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50, page_size: Annotated[int, fastapi.Query()] = 50,

View File

@@ -96,7 +96,13 @@ class ChatConfig(BaseSettings):
# Extended thinking configuration for Claude models # Extended thinking configuration for Claude models
thinking_enabled: bool = Field( thinking_enabled: bool = Field(
default=True, default=True,
description="Enable adaptive thinking for Claude models via OpenRouter", description="Enable extended thinking for Claude models via OpenRouter",
)
thinking_budget_tokens: int = Field(
default=10000,
ge=1000,
le=100000,
description="Maximum tokens for extended thinking (budget_tokens for Claude)",
) )
@field_validator("api_key", mode="before") @field_validator("api_key", mode="before")

View File

@@ -2,7 +2,7 @@ import asyncio
import logging import logging
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, cast from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from openai.types.chat import ( from openai.types.chat import (
@@ -104,26 +104,6 @@ class ChatSession(BaseModel):
successful_agent_runs: dict[str, int] = {} successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {} successful_agent_schedules: dict[str, int] = {}
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
@staticmethod @staticmethod
def new(user_id: str) -> "ChatSession": def new(user_id: str) -> "ChatSession":
return ChatSession( return ChatSession(
@@ -192,47 +172,6 @@ class ChatSession(BaseModel):
successful_agent_schedules=successful_agent_schedules, successful_agent_schedules=successful_agent_schedules,
) )
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
if len(messages) < 2:
return messages
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
def to_openai_messages(self) -> list[ChatCompletionMessageParam]: def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = [] messages = []
for message in self.messages: for message in self.messages:
@@ -319,7 +258,7 @@ class ChatSession(BaseModel):
name=message.name or "", name=message.name or "",
) )
) )
return self._merge_consecutive_assistant_messages(messages) return messages
async def _get_session_from_cache(session_id: str) -> ChatSession | None: async def _get_session_from_cache(session_id: str) -> ChatSession | None:

View File

@@ -1,16 +1,4 @@
from typing import cast
import pytest import pytest
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from .model import ( from .model import (
ChatMessage, ChatMessage,
@@ -129,205 +117,3 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
loaded.tool_calls is not None loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message" ), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls) assert len(orig.tool_calls) == len(loaded.tool_calls)
# --------------------------------------------------------------------------- #
# _merge_consecutive_assistant_messages #
# --------------------------------------------------------------------------- #
_tc = ChatCompletionMessageToolCallParam(
id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
)
_tc2 = ChatCompletionMessageToolCallParam(
id="tc2", type="function", function=Function(name="other", arguments="{}")
)
def test_merge_noop_when_no_consecutive_assistants():
"""Messages without consecutive assistants are returned unchanged."""
msgs = [
ChatCompletionUserMessageParam(role="user", content="hi"),
ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
ChatCompletionUserMessageParam(role="user", content="bye"),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
assert len(merged) == 3
assert [m["role"] for m in merged] == ["user", "assistant", "user"]
def test_merge_splits_text_and_tool_calls():
"""The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
msgs = [
ChatCompletionUserMessageParam(role="user", content="build agent"),
ChatCompletionAssistantMessageParam(
role="assistant", content="Let me build that"
),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc]
),
ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
assert len(merged) == 3
assert merged[0]["role"] == "user"
assert merged[2]["role"] == "tool"
a = cast(ChatCompletionAssistantMessageParam, merged[1])
assert a["role"] == "assistant"
assert a.get("content") == "Let me build that"
assert a.get("tool_calls") == [_tc]
def test_merge_combines_tool_calls_from_both():
"""Both consecutive assistants have tool_calls — they get merged."""
msgs: list[ChatCompletionAssistantMessageParam] = [
ChatCompletionAssistantMessageParam(
role="assistant", content="text", tool_calls=[_tc]
),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc2]
),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
assert len(merged) == 1
a = cast(ChatCompletionAssistantMessageParam, merged[0])
assert a.get("tool_calls") == [_tc, _tc2]
assert a.get("content") == "text"
def test_merge_three_consecutive_assistants():
"""Three consecutive assistants collapse into one."""
msgs: list[ChatCompletionAssistantMessageParam] = [
ChatCompletionAssistantMessageParam(role="assistant", content="a"),
ChatCompletionAssistantMessageParam(role="assistant", content="b"),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc]
),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
assert len(merged) == 1
a = cast(ChatCompletionAssistantMessageParam, merged[0])
assert a.get("content") == "a\nb"
assert a.get("tool_calls") == [_tc]
def test_merge_empty_and_single_message():
"""Edge cases: empty list and single message."""
assert ChatSession._merge_consecutive_assistant_messages([]) == []
single: list[ChatCompletionMessageParam] = [
ChatCompletionUserMessageParam(role="user", content="hi")
]
assert ChatSession._merge_consecutive_assistant_messages(single) == single
# --------------------------------------------------------------------------- #
# add_tool_call_to_current_turn #
# --------------------------------------------------------------------------- #
_raw_tc = {
"id": "tc1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
_raw_tc2 = {
"id": "tc2",
"type": "function",
"function": {"name": "g", "arguments": "{}"},
}
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 2 # no new message created
assert session.messages[1].tool_calls == [_raw_tc]
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 2
assert session.messages[1].role == "assistant"
assert session.messages[1].tool_calls == [_raw_tc]
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 3 # new assistant was created
assert session.messages[0].tool_calls is None # old assistant untouched
assert session.messages[2].role == "assistant"
assert session.messages[2].tool_calls == [_raw_tc]
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
]
session.add_tool_call_to_current_turn(_raw_tc)
# Simulate a pending tool result in between (like _yield_tool_call does)
session.messages.append(
ChatMessage(role="tool", content="pending", tool_call_id="tc1")
)
session.add_tool_call_to_current_turn(_raw_tc2)
assert len(session.messages) == 3 # user, assistant, tool — no extra assistant
assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "tc1",
"type": "function",
"function": {"name": "create_agent", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", content="done", tool_call_id="tc1"),
ChatMessage(role="assistant", content="Saved!"),
ChatMessage(role="user", content="show me an example run"),
]
openai_msgs = session.to_openai_messages()
# The two consecutive assistants at index 1,2 should be merged
roles = [m["role"] for m in openai_msgs]
assert roles == ["user", "assistant", "tool", "assistant", "user"]
# The merged assistant should have both content and tool_calls
merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
assert merged.get("content") == "Let me build that"
tc_list = merged.get("tool_calls")
assert tc_list is not None and len(list(tc_list)) == 1
assert list(tc_list)[0]["id"] == "tc1"

View File

@@ -10,8 +10,6 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
class ResponseType(str, Enum): class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol.""" """Types of streaming responses following AI SDK protocol."""
@@ -195,18 +193,6 @@ class StreamError(StreamBaseResponse):
default=None, description="Additional error details" default=None, description="Additional error details"
) )
def to_sse(self) -> str:
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.
"""
data = {
"type": self.type.value,
"errorText": self.errorText,
}
return f"data: {json_dumps(data)}\n\n"
class StreamHeartbeat(StreamBaseResponse): class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations. """Heartbeat to keep SSE connection alive during long-running operations.

View File

@@ -80,6 +80,19 @@ settings = Settings()
client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
def _apply_thinking_config(extra_body: dict[str, Any], model: str) -> None:
"""Apply extended thinking configuration for Anthropic models via OpenRouter.
OpenRouter's reasoning API expects either:
- {"max_tokens": N} for explicit token budget
- {"effort": "high"} for automatic budget
See: https://openrouter.ai/docs/reasoning-tokens
"""
if config.thinking_enabled and "anthropic" in model.lower():
extra_body["reasoning"] = {"max_tokens": config.thinking_budget_tokens}
langfuse = get_client() langfuse = get_client()
# Redis key prefix for tracking running long-running operations # Redis key prefix for tracking running long-running operations
@@ -800,13 +813,9 @@ async def stream_chat_completion(
# Build the messages list in the correct order # Build the messages list in the correct order
messages_to_save: list[ChatMessage] = [] messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any. # Add assistant message with tool_calls if any
# Use extend (not assign) to preserve tool_calls already added by
# _yield_tool_call for long-running tools.
if accumulated_tool_calls: if accumulated_tool_calls:
if not assistant_response.tool_calls: assistant_response.tool_calls = accumulated_tool_calls
assistant_response.tool_calls = []
assistant_response.tool_calls.extend(accumulated_tool_calls)
logger.info( logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message" f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
) )
@@ -1070,9 +1079,8 @@ async def _stream_chat_chunks(
:128 :128
] # OpenRouter limit ] # OpenRouter limit
# Enable adaptive thinking for Anthropic models via OpenRouter # Enable extended thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in model.lower(): _apply_thinking_config(extra_body, model)
extra_body["reasoning"] = {"enabled": True}
api_call_start = time_module.perf_counter() api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create( stream = await client.chat.completions.create(
@@ -1408,9 +1416,13 @@ async def _yield_tool_call(
operation_id=operation_id, operation_id=operation_id,
) )
# Attach the tool_call to the current turn's assistant message # Save assistant message with tool_call FIRST (required by LLM)
# (or create one if this is a tool-only response with no text). assistant_message = ChatMessage(
session.add_tool_call_to_current_turn(tool_calls[yield_idx]) role="assistant",
content="",
tool_calls=[tool_calls[yield_idx]],
)
session.messages.append(assistant_message)
# Then save pending tool result # Then save pending tool result
pending_message = ChatMessage( pending_message = ChatMessage(
@@ -1833,9 +1845,8 @@ async def _generate_llm_continuation(
if session_id: if session_id:
extra_body["session_id"] = session_id[:128] extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter # Enable extended thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower(): _apply_thinking_config(extra_body, config.model)
extra_body["reasoning"] = {"enabled": True}
retry_count = 0 retry_count = 0
last_error: Exception | None = None last_error: Exception | None = None
@@ -1967,9 +1978,8 @@ async def _generate_llm_continuation_with_streaming(
if session_id: if session_id:
extra_body["session_id"] = session_id[:128] extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter # Enable extended thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower(): _apply_thinking_config(extra_body, config.model)
extra_body["reasoning"] = {"enabled": True}
# Make streaming LLM call (no tools - just text response) # Make streaming LLM call (no tools - just text response)
from typing import cast from typing import cast

View File

@@ -13,8 +13,7 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse, NoResultsResponse,
) )
from backend.api.features.store.hybrid_search import unified_hybrid_search from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import get_block from backend.data.block import BlockType, get_block
from backend.blocks._base import BlockType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -10,7 +10,7 @@ from backend.api.features.chat.tools.find_block import (
FindBlockTool, FindBlockTool,
) )
from backend.api.features.chat.tools.models import BlockListResponse from backend.api.features.chat.tools.models import BlockListResponse
from backend.blocks._base import BlockType from backend.data.block import BlockType
from ._test_data import make_session from ._test_data import make_session

View File

@@ -12,8 +12,7 @@ from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES, COPILOT_EXCLUDED_BLOCK_TYPES,
) )
from backend.blocks import get_block from backend.data.block import AnyBlockSchema, get_block
from backend.blocks._base import AnyBlockSchema
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace from backend.data.workspace import get_or_create_workspace

View File

@@ -6,7 +6,7 @@ import pytest
from backend.api.features.chat.tools.models import ErrorResponse from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType from backend.data.block import BlockType
from ._test_data import make_session from ._test_data import make_session

View File

@@ -15,7 +15,6 @@ from backend.data.model import (
OAuth2Credentials, OAuth2Credentials,
) )
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -360,7 +359,7 @@ async def match_user_credentials_to_graph(
_, _,
_, _,
) in aggregated_creds.items(): ) in aggregated_creds.items():
# Find first matching credential by provider, type, scopes, and host/URL # Find first matching credential by provider, type, and scopes
matching_cred = next( matching_cred = next(
( (
cred cred
@@ -375,10 +374,6 @@ async def match_user_credentials_to_graph(
cred.type != "host_scoped" cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements) or _credential_is_for_host(cred, credential_requirements)
) )
and (
cred.provider != ProviderName.MCP
or _credential_is_for_mcp_server(cred, credential_requirements)
)
), ),
None, None,
) )
@@ -449,22 +444,6 @@ def _credential_is_for_host(
return credential.matches_url(list(requirements.discriminator_values)[0]) return credential.matches_url(list(requirements.discriminator_values)[0])
def _credential_is_for_mcp_server(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an MCP OAuth credential matches the required server URL."""
if not requirements.discriminator_values:
return True
server_url = (
credential.metadata.get("mcp_server_url")
if isinstance(credential, OAuth2Credentials)
else None
)
return server_url in requirements.discriminator_values if server_url else False
async def check_user_has_required_credentials( async def check_user_has_required_credentials(
user_id: str, user_id: str,
required_credentials: list[CredentialsMetaInput], required_credentials: list[CredentialsMetaInput],

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Any, List, Literal from typing import TYPE_CHECKING, Annotated, List, Literal
from autogpt_libs.auth import get_user_id from autogpt_libs.auth import get_user_id
from fastapi import ( from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
Security, Security,
status, status,
) )
from pydantic import BaseModel, Field, SecretStr, model_validator from pydantic import BaseModel, Field, SecretStr
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager from backend.integrations.webhooks import get_webhook_manager
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
scopes: list[str] | None scopes: list[str] | None
username: str | None username: str | None
host: str | None = Field( host: str | None = Field(
default=None, default=None, description="Host pattern for host-scoped credentials"
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
) )
@model_validator(mode="before")
@classmethod
def _normalize_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@staticmethod
def get_host(cred: Credentials) -> str | None:
"""Extract host from credential: HostScoped host or MCP server URL."""
if isinstance(cred, HostScopedCredentials):
return cred.host
if isinstance(cred, OAuth2Credentials) and cred.provider in (
ProviderName.MCP,
ProviderName.MCP.value,
"ProviderName.MCP",
):
return (cred.metadata or {}).get("mcp_server_url")
return None
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens") @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback( async def callback(
@@ -211,7 +179,9 @@ async def callback(
title=credentials.title, title=credentials.title,
scopes=credentials.scopes, scopes=credentials.scopes,
username=credentials.username, username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)), host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
) )
@@ -229,7 +199,7 @@ async def list_credentials(
title=cred.title, title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred), host=cred.host if isinstance(cred, HostScopedCredentials) else None,
) )
for cred in credentials for cred in credentials
] ]
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
title=cred.title, title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred), host=cred.host if isinstance(cred, HostScopedCredentials) else None,
) )
for cred in credentials for cred in credentials
] ]
@@ -352,11 +322,7 @@ async def delete_credentials(
tokens_revoked = None tokens_revoked = None
if isinstance(creds, OAuth2Credentials): if isinstance(creds, OAuth2Credentials):
if provider_matches(provider.value, ProviderName.MCP.value): handler = _get_provider_oauth_handler(request, provider)
# MCP uses dynamic per-server OAuth — create handler from metadata
handler = create_mcp_oauth_handler(creds)
else:
handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = await handler.revoke_tokens(creds) tokens_revoked = await handler.revoke_tokens(creds)
return CredentialsDeletionResponse(revoked=tokens_revoked) return CredentialsDeletionResponse(revoked=tokens_revoked)

View File

@@ -12,11 +12,12 @@ import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media import backend.api.features.store.media as store_media
import backend.data.graph as graph_db import backend.data.graph as graph_db
import backend.data.integrations as integrations_db import backend.data.integrations as integrations_db
from backend.data.block import BlockInput
from backend.data.db import transaction from backend.data.db import transaction
from backend.data.execution import get_graph_execution from backend.data.execution import get_graph_execution
from backend.data.graph import GraphSettings from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput, GraphInput from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import ( from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate, on_graph_activate,
@@ -1129,7 +1130,7 @@ async def create_preset_from_graph_execution(
async def update_preset( async def update_preset(
user_id: str, user_id: str,
preset_id: str, preset_id: str,
inputs: Optional[GraphInput] = None, inputs: Optional[BlockInput] = None,
credentials: Optional[dict[str, CredentialsMetaInput]] = None, credentials: Optional[dict[str, CredentialsMetaInput]] = None,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,

View File

@@ -6,12 +6,9 @@ import prisma.enums
import prisma.models import prisma.models
import pydantic import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import ( from backend.data.model import CredentialsMetaInput, is_credentials_field_name
CredentialsMetaInput,
GraphInput,
is_credentials_field_name,
)
from backend.util.json import loads as json_loads from backend.util.json import loads as json_loads
from backend.util.models import Pagination from backend.util.models import Pagination
@@ -326,7 +323,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
graph_id: str graph_id: str
graph_version: int graph_version: int
inputs: GraphInput inputs: BlockInput
credentials: dict[str, CredentialsMetaInput] credentials: dict[str, CredentialsMetaInput]
name: str name: str
@@ -355,7 +352,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
Request model used when updating a preset for a library agent. Request model used when updating a preset for a library agent.
""" """
inputs: Optional[GraphInput] = None inputs: Optional[BlockInput] = None
credentials: Optional[dict[str, CredentialsMetaInput]] = None credentials: Optional[dict[str, CredentialsMetaInput]] = None
name: Optional[str] = None name: Optional[str] = None
@@ -398,7 +395,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
"Webhook must be included in AgentPreset query when webhookId is set" "Webhook must be included in AgentPreset query when webhookId is set"
) )
input_data: GraphInput = {} input_data: BlockInput = {}
input_credentials: dict[str, CredentialsMetaInput] = {} input_credentials: dict[str, CredentialsMetaInput] = {}
for preset_input in preset.InputPresets: for preset_input in preset.InputPresets:

View File

@@ -1,404 +0,0 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -1,436 +0,0 @@
"""Tests for MCP API routes.
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
"""
from unittest.mock import AsyncMock, patch
import fastapi
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
@pytest_asyncio.fixture(scope="module")
async def client():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
yield c
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_auth_token(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = await client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_missing_url(self, client):
response = await client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_success(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_no_oauth_support(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = await client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_fallback_to_public_client(self, client):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = await client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_invalid_state(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = await client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_token_exchange_fails(self, client):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = await client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()

View File

@@ -5,8 +5,8 @@ from typing import Optional
import aiohttp import aiohttp
from fastapi import HTTPException from fastapi import HTTPException
from backend.blocks import get_block
from backend.data import graph as graph_db from backend.data import graph as graph_db
from backend.data.block import get_block
from backend.util.settings import Settings from backend.util.settings import Settings
from .models import ApiResponse, ChatRequest, GraphData from .models import ApiResponse, ChatRequest, GraphData

View File

@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
async def get_missing_items(self, batch_size: int) -> list[ContentItem]: async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings.""" """Fetch blocks without embeddings."""
from backend.blocks import get_blocks from backend.data.block import get_blocks
# Get all available blocks # Get all available blocks
all_blocks = get_blocks() all_blocks = get_blocks()
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
async def get_stats(self) -> dict[str, int]: async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage.""" """Get statistics about block embedding coverage."""
from backend.blocks import get_blocks from backend.data.block import get_blocks
all_blocks = get_blocks() all_blocks = get_blocks()

View File

@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
mock_existing = [] mock_existing = []
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
mock_embedded = [{"count": 2}] mock_embedded = [{"count": 2}]
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
mock_blocks = {"block-minimal": mock_block_class} mock_blocks = {"block-minimal": mock_block_class}
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
mock_blocks = {"good-block": good_block, "bad-block": bad_block} mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(

View File

@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
) )
current_ids = {row["id"] for row in valid_agents} current_ids = {row["id"] for row in valid_agents}
elif content_type == ContentType.BLOCK: elif content_type == ContentType.BLOCK:
from backend.blocks import get_blocks from backend.data.block import get_blocks
current_ids = set(get_blocks().keys()) current_ids = set(get_blocks().keys())
elif content_type == ContentType.DOCUMENTATION: elif content_type == ContentType.DOCUMENTATION:

View File

@@ -7,6 +7,15 @@ from replicate.client import Client as ReplicateClient
from replicate.exceptions import ReplicateError from replicate.exceptions import ReplicateError
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
from backend.data.graph import GraphBaseMeta from backend.data.graph import GraphBaseMeta
from backend.data.model import CredentialsMetaInput, ProviderName from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials from backend.integrations.credentials_store import ideogram_credentials
@@ -41,16 +50,6 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
if not ideogram_credentials.api_key: if not ideogram_credentials.api_key:
raise ValueError("Missing Ideogram API key") raise ValueError("Missing Ideogram API key")
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
name = graph.name name = graph.name
description = f"{name} ({graph.description})" if graph.description else name description = f"{name} ({graph.description})" if graph.description else name

View File

@@ -40,11 +40,10 @@ from backend.api.model import (
UpdateTimezoneRequest, UpdateTimezoneRequest,
UploadFileResponse, UploadFileResponse,
) )
from backend.blocks import get_block, get_blocks
from backend.data import execution as execution_db from backend.data import execution as execution_db
from backend.data import graph as graph_db from backend.data import graph as graph_db
from backend.data.auth import api_key as api_key_db from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import ( from backend.data.credit import (
AutoTopUpConfig, AutoTopUpConfig,
RefundRequest, RefundRequest,

View File

@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db import backend.api.features.library.db
import backend.api.features.library.model import backend.api.features.library.model
import backend.api.features.library.routes import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth import backend.api.features.oauth
import backend.api.features.otto.routes import backend.api.features.otto.routes
import backend.api.features.postmark.postmark import backend.api.features.postmark.postmark
@@ -344,11 +343,6 @@ app.include_router(
tags=["workspace"], tags=["workspace"],
prefix="/api/workspace", prefix="/api/workspace",
) )
app.include_router(
mcp_routes.router,
tags=["v2", "mcp"],
prefix="/api/mcp",
)
app.include_router( app.include_router(
backend.api.features.oauth.router, backend.api.features.oauth.router,
tags=["oauth"], tags=["oauth"],

View File

@@ -3,19 +3,22 @@ import logging
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Sequence, Type, TypeVar from typing import TYPE_CHECKING, TypeVar
from backend.blocks._base import AnyBlockSchema, BlockType
from backend.util.cache import cached from backend.util.cache import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T") T = TypeVar("T")
@cached(ttl_seconds=3600) @cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]: def load_all_blocks() -> dict[str, type["Block"]]:
from backend.blocks._base import Block from backend.data.block import Block
from backend.util.settings import Config from backend.util.settings import Config
# Check if example blocks should be loaded from settings # Check if example blocks should be loaded from settings
@@ -47,8 +50,8 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
importlib.import_module(f".{module}", package=__name__) importlib.import_module(f".{module}", package=__name__)
# Load all Block instances from the available modules # Load all Block instances from the available modules
available_blocks: dict[str, type["AnyBlockSchema"]] = {} available_blocks: dict[str, type["Block"]] = {}
for block_cls in _all_subclasses(Block): for block_cls in all_subclasses(Block):
class_name = block_cls.__name__ class_name = block_cls.__name__
if class_name.endswith("Base"): if class_name.endswith("Base"):
@@ -61,7 +64,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
"please name the class with 'Base' at the end" "please name the class with 'Base' at the end"
) )
block = block_cls() # pyright: ignore[reportAbstractUsage] block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36: if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError( raise ValueError(
@@ -102,7 +105,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
available_blocks[block.id] = block_cls available_blocks[block.id] = block_cls
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets # Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
from ._utils import is_block_auth_configured from backend.data.block import is_block_auth_configured
filtered_blocks = {} filtered_blocks = {}
for block_id, block_cls in available_blocks.items(): for block_id, block_cls in available_blocks.items():
@@ -112,48 +115,11 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
return filtered_blocks return filtered_blocks
def _all_subclasses(cls: type[T]) -> list[type[T]]: __all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
subclasses = cls.__subclasses__() subclasses = cls.__subclasses__()
for subclass in subclasses: for subclass in subclasses:
subclasses += _all_subclasses(subclass) subclasses += all_subclasses(subclass)
return subclasses return subclasses
# ============== Block access helper functions ============== #
def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
return load_all_blocks()
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> "AnyBlockSchema | None":
cls = get_blocks().get(block_id)
return cls() if cls else None
@cached(ttl_seconds=3600)
def get_webhook_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
]
@cached(ttl_seconds=3600)
def get_io_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
]
@cached(ttl_seconds=3600)
def get_human_in_the_loop_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type == BlockType.HUMAN_IN_THE_LOOP
]

View File

@@ -1,740 +0,0 @@
import inspect
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Optional,
Type,
TypeAlias,
TypeVar,
cast,
get_origin,
)
import jsonref
import jsonschema
from pydantic import BaseModel
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
SchemaField,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.exceptions import (
BlockError,
BlockExecutionError,
BlockInputError,
BlockOutputError,
BlockUnknownError,
)
from backend.util.settings import Config
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, NodeExecutionStats
from ..data.graph import Link
app_config = Config()
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
class BlockType(Enum):
STANDARD = "Standard"
INPUT = "Input"
OUTPUT = "Output"
NOTE = "Note"
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"
MCP_TOOL = "MCP Tool"
class BlockCategory(Enum):
AI = "Block that leverages AI to perform a task."
SOCIAL = "Block that interacts with social media platforms."
TEXT = "Block that processes text data."
SEARCH = "Block that searches or extracts information from the internet."
BASIC = "Block that performs basic operations."
INPUT = "Block that interacts with input of the graph."
OUTPUT = "Block that interacts with output of the graph."
LOGIC = "Programming logic to control the flow of your agent"
COMMUNICATION = "Block that interacts with communication platforms."
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
DATA = "Block that interacts with structured data."
HARDWARE = "Block that interacts with hardware."
AGENT = "Block that interacts with other agents."
CRM = "Block that interacts with CRM services."
SAFETY = (
"Block that provides AI safety mechanisms such as detecting harmful content"
)
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
**data,
)
class BlockInfo(BaseModel):
id: str
name: str
inputSchema: dict[str, Any]
outputSchema: dict[str, Any]
costs: list[BlockCost]
description: str
categories: list[dict[str, str]]
contributors: list[dict[str, Any]]
staticOutput: bool
uiType: str
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
if cls.cached_jsonschema:
return cls.cached_jsonschema
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
raise ValueError(f"Invalid model schema {cls}")
property_schema = model_schema.get(field_name)
if not property_schema:
raise ValueError(f"Invalid property name {field_name}")
return property_schema
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
Validate the data against a specific property (one of the input/output name).
Returns the validation error message if the data does not match the schema.
"""
try:
property_schema = cls.get_field_schema(field_name)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
return str(e)
@classmethod
def get_fields(cls) -> set[str]:
return set(cls.model_fields.keys())
@classmethod
def get_required_fields(cls) -> set[str]:
return {
field
for field, field_info in cls.model_fields.items()
if field_info.is_required()
}
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""Validates the schema definition. Rules:
- Fields with annotation `CredentialsMetaInput` MUST be
named `credentials` or `*_credentials`
- Fields named `credentials` or `*_credentials` MUST be
of type `CredentialsMetaInput`
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = cls.get_credentials_fields()
for field_name in cls.get_fields():
if is_credentials_field_name(field_name):
if field_name not in credentials_fields:
raise TypeError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
f"is not of type {CredentialsMetaInput.__name__}"
)
CredentialsMetaInput.validate_credentials_field_schema(
cls.get_field_schema(field_name), field_name
)
elif field_name in credentials_fields:
raise KeyError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
"has invalid name: must be 'credentials' or *_credentials"
)
@classmethod
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
return {
field_name: info.annotation
for field_name, info in cls.model_fields.items()
if (
inspect.isclass(info.annotation)
and issubclass(
get_origin(info.annotation) or info.annotation,
CredentialsMetaInput,
)
)
}
@classmethod
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
"""
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
Raises:
ValueError: If multiple fields have the same kwarg_name, as this would
cause silent overwriting and only the last field would be processed.
"""
result: dict[str, dict[str, Any]] = {}
schema = cls.jsonschema()
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
auto_creds = field_schema.get("auto_credentials")
if auto_creds:
kwarg_name = auto_creds.get("kwarg_name", "credentials")
if kwarg_name in result:
raise ValueError(
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
f"in fields '{result[kwarg_name]['field_name']}' and "
f"'{field_name}' on {cls.__qualname__}"
)
result[kwarg_name] = {
"field_name": field_name,
"config": auto_creds,
}
return result
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
result = {}
# Regular credentials fields
for field_name in cls.get_credentials_fields().keys():
result[field_name] = CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
for kwarg_name, info in cls.get_auto_credentials_fields().items():
config = info["config"]
# Build a schema-like dict that CredentialsFieldInfo can parse
auto_schema = {
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True
)
return result
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
class BlockSchemaInput(BlockSchema):
"""
Base schema class for block inputs.
All block input schemas should extend this class for consistency.
"""
pass
class BlockSchemaOutput(BlockSchema):
"""
Base schema class for block outputs that includes a standard error field.
All block output schemas should extend this class to ensure consistent error handling.
"""
error: str = SchemaField(
description="Error message if the operation failed", default=""
)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
class EmptyInputSchema(BlockSchemaInput):
pass
class EmptyOutputSchema(BlockSchemaOutput):
pass
# For backward compatibility - will be deprecated
EmptySchema = EmptyOutputSchema
# --8<-- [start:BlockWebhookConfig]
class BlockManualWebhookConfig(BaseModel):
"""
Configuration model for webhook-triggered blocks on which
the user has to manually set up the webhook at the provider.
"""
provider: ProviderName
"""The service provider that the webhook connects to"""
webhook_type: str
"""
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
Only for use in the corresponding `WebhooksManager`.
"""
event_filter_input: str = ""
"""
Name of the block's event filter input.
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
"""
event_format: str = "{event}"
"""
Template string for the event(s) that a block instance subscribes to.
Applied individually to each event selected in the event filter input.
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
"""
class BlockWebhookConfig(BlockManualWebhookConfig):
"""
Configuration model for webhook-triggered blocks for which
the webhook can be automatically set up through the provider's API.
"""
resource_format: str
"""
Template string for the resource that a block instance subscribes to.
Fields will be filled from the block's inputs (except `payload`).
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
Only for use in the corresponding `WebhooksManager`.
"""
# --8<-- [end:BlockWebhookConfig]
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def __init__(
self,
id: str = "",
description: str = "",
contributors: list["ContributorDetails"] = [],
categories: set[BlockCategory] | None = None,
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
is_sensitive_action: bool = False,
):
"""
Initialize the block with the given schema.
Args:
id: The unique identifier for the block, this value will be persisted in the
DB. So it should be a unique and constant across the application run.
Use the UUID format for the ID.
description: The description of the block, explaining what the block does.
contributors: The list of contributors who contributed to the block.
input_schema: The schema, defined as a Pydantic model, for the input data.
output_schema: The schema, defined as a Pydantic model, for the output data.
test_input: The list or single sample input data for the block, for testing.
test_output: The list or single expected output if the test_input is run.
test_mock: function names on the block implementation to mock on test run.
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
self.test_input = test_input
self.test_output = test_output
self.test_mock = test_mock
self.test_credentials = test_credentials
self.description = description
self.categories = categories or set()
self.contributors = contributors or set()
self.disabled = disabled
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
# Enforce presence of credentials field on auto-setup webhook blocks
if not (cred_fields := self.input_schema.get_credentials_fields()):
raise TypeError(
"credentials field is required on auto-setup webhook blocks"
)
# Disallow multiple credentials inputs on webhook blocks
elif len(cred_fields) > 1:
raise ValueError(
"Multiple credentials inputs not supported on webhook blocks"
)
self.block_type = BlockType.WEBHOOK
else:
self.block_type = BlockType.WEBHOOK_MANUAL
# Enforce shape of webhook event filter, if present
if self.webhook_config.event_filter_input:
event_filter_field = self.input_schema.model_fields[
self.webhook_config.event_filter_input
]
if not (
isinstance(event_filter_field.annotation, type)
and issubclass(event_filter_field.annotation, BaseModel)
and all(
field.annotation is bool
for field in event_filter_field.annotation.model_fields.values()
)
):
raise NotImplementedError(
f"{self.name} has an invalid webhook event selector: "
"field must be a BaseModel and all its fields must be boolean"
)
# Enforce presence of 'payload' input
if "payload" not in self.input_schema.model_fields:
raise TypeError(
f"{self.name} is webhook-triggered but has no 'payload' input"
)
# Disable webhook-triggered block if webhook functionality not available
if not app_config.platform_base_url:
self.disabled = True
@abstractmethod
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
"""
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: Currently 14/02/2025 these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
output_data: The data for the output_name, matching the defined schema.
"""
# --- satisfy the type checker, never executed -------------
if False: # noqa: SIM115
yield "name", "value" # pyright: ignore[reportMissingYield]
raise NotImplementedError(f"{self.name} does not implement the run method.")
async def run_once(
self, input_data: BlockSchemaInputType, output: str, **kwargs
) -> Any:
async for item in self.run(input_data, **kwargs):
name, data = item
if name == output:
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
self.execution_stats += stats
return self.execution_stats
@property
def name(self):
return self.__class__.__name__
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"inputSchema": self.input_schema.jsonschema(),
"outputSchema": self.output_schema.jsonschema(),
"description": self.description,
"categories": [category.dict() for category in self.categories],
"contributors": [
contributor.model_dump() for contributor in self.contributors
],
"staticOutput": self.static_output,
"uiType": self.block_type.value,
}
def get_info(self) -> BlockInfo:
from backend.data.credit import get_block_cost
return BlockInfo(
id=self.id,
name=self.name,
inputSchema=self.input_schema.jsonschema(),
outputSchema=self.output_schema.jsonschema(),
costs=get_block_cost(self),
description=self.description,
categories=[category.dict() for category in self.categories],
contributors=[
contributor.model_dump() for contributor in self.contributors
],
staticOutput=self.static_output,
uiType=self.block_type.value,
)
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
try:
async for output_name, output_data in self._execute(input_data, **kwargs):
yield output_name, output_data
except Exception as ex:
if isinstance(ex, BlockError):
raise ex
else:
raise (
BlockExecutionError
if isinstance(ex, ValueError)
else BlockUnknownError
)(
message=str(ex),
block_name=self.name,
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
if not (
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
):
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement only if running within a graph execution context
# Direct block execution (e.g., from chat) skips the review process
has_graph_context = all(
key in kwargs
for key in (
"node_exec_id",
"graph_exec_id",
"graph_id",
"execution_context",
)
)
if has_graph_context:
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,
):
if output_name == "error":
raise BlockExecutionError(
message=output_data, block_name=self.name, block_id=self.id
)
if self.block_type == BlockType.STANDARD and (
error := self.output_schema.validate_field(output_name, output_data)
):
raise BlockOutputError(
message=f"Block produced an invalid output data: {error}",
block_name=self.name,
block_id=self.id,
)
yield output_name, output_data
def is_triggered_by_event_type(
self, trigger_config: dict[str, Any], event_type: str
) -> bool:
if not self.webhook_config:
raise TypeError("This method can't be used on non-trigger blocks")
if not self.webhook_config.event_filter_input:
return True
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
if not event_filter:
raise ValueError("Event filter is not configured on trigger")
return event_type in [
self.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
# Type alias for any block with standard input/output schemas
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]

View File

@@ -1,122 +0,0 @@
import logging
import os
from backend.integrations.providers import ProviderName
from ._base import AnyBlockSchema
logger = logging.getLogger(__name__)
def is_block_auth_configured(
block_cls: type[AnyBlockSchema],
) -> bool:
"""
Check if a block has a valid authentication method configured at runtime.
For example if a block is an OAuth-only block and there env vars are not set,
do not show it in the UI.
"""
from backend.sdk.registry import AutoRegistry
# Create an instance to access input_schema
try:
block = block_cls()
except Exception as e:
# If we can't create a block instance, assume it's not OAuth-only
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
return True
logger.debug(
f"Checking if block {block_cls.__name__} has a valid provider configured"
)
# Get all credential inputs from input schema
credential_inputs = block.input_schema.get_credentials_fields_info()
required_inputs = block.input_schema.get_required_fields()
if not credential_inputs:
logger.debug(
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
)
return True
# Check credential inputs
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
logger.debug(
f"Block {block_cls.__name__} has only optional credential inputs"
" - will work without credentials configured"
)
# Check if the credential inputs for this block are correctly configured
for field_name, field_info in credential_inputs.items():
provider_names = field_info.provider
if not provider_names:
logger.warning(
f"Block {block_cls.__name__} "
f"has credential input '{field_name}' with no provider options"
" - Disabling"
)
return False
# If a field has multiple possible providers, each one needs to be usable to
# prevent breaking the UX
for _provider_name in provider_names:
provider_name = _provider_name.value
if provider_name in ProviderName.__members__.values():
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is part of the legacy provider system"
" - Treating as valid"
)
break
provider = AutoRegistry.get_provider(provider_name)
if not provider:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"refers to unknown provider '{provider_name}' - Disabling"
)
return False
# Check the provider's supported auth types
if field_info.supported_types != provider.supported_auth_types:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"has mismatched supported auth types (field <> Provider): "
f"{field_info.supported_types} != {provider.supported_auth_types}"
)
if not (supported_auth_types := provider.supported_auth_types):
# No auth methods are been configured for this provider
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"has no authentication methods configured - Disabling"
)
return False
# Check if provider supports OAuth
if "oauth2" in supported_auth_types:
# Check if OAuth environment variables are set
if (oauth_config := provider.oauth_config) and bool(
os.getenv(oauth_config.client_id_env_var)
and os.getenv(oauth_config.client_secret_env_var)
):
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is configured for OAuth"
)
else:
logger.error(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"is missing OAuth client ID or secret - Disabling"
)
return False
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
f"supported credential types: {', '.join(field_info.supported_types)}"
)
return True

View File

@@ -1,7 +1,7 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Optional from typing import Any, Optional
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockInput, BlockInput,
@@ -9,15 +9,13 @@ from backend.blocks._base import (
BlockSchema, BlockSchema,
BlockSchemaInput, BlockSchemaInput,
BlockType, BlockType,
get_block,
) )
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
from backend.data.model import NodeExecutionStats, SchemaField from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry from backend.util.retry import func_retry
if TYPE_CHECKING:
from backend.executor.utils import LogMetadata
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@@ -126,10 +124,9 @@ class AgentExecutorBlock(Block):
graph_version: int, graph_version: int,
graph_exec_id: str, graph_exec_id: str,
user_id: str, user_id: str,
logger: "LogMetadata", logger,
) -> BlockOutput: ) -> BlockOutput:
from backend.blocks import get_block
from backend.data.execution import ExecutionEventType from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils
@@ -201,7 +198,7 @@ class AgentExecutorBlock(Block):
self, self,
graph_exec_id: str, graph_exec_id: str,
user_id: str, user_id: str,
logger: "LogMetadata", logger,
) -> None: ) -> None:
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils

View File

@@ -1,11 +1,5 @@
from typing import Any from typing import Any
from backend.blocks._base import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.llm import ( from backend.blocks.llm import (
DEFAULT_LLM_MODEL, DEFAULT_LLM_MODEL,
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -17,6 +11,12 @@ from backend.blocks.llm import (
LLMResponse, LLMResponse,
llm_call, llm_call,
) )
from backend.data.block import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

View File

@@ -6,7 +6,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -5,12 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.blocks._base import ( from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import ( from backend.data.model import (
APIKeyCredentials, APIKeyCredentials,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr from pydantic import SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,10 +1,3 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -17,6 +10,13 @@ from backend.blocks.apollo.models import (
PrimaryPhone, PrimaryPhone,
SearchOrganizationsRequest, SearchOrganizationsRequest,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,12 +1,5 @@
import asyncio import asyncio
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -21,6 +14,13 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest, SearchPeopleRequest,
SenorityLevels, SenorityLevels,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,10 +1,3 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -13,6 +6,13 @@ from backend.blocks.apollo._auth import (
ApolloCredentialsInput, ApolloCredentialsInput,
) )
from backend.blocks.apollo.models import Contact, EnrichPersonRequest from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.blocks._base import BlockSchemaInput from backend.data.block import BlockSchemaInput
from backend.data.model import SchemaField, UserIntegrations from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client from backend.util.clients import get_database_manager_async_client

View File

@@ -1,7 +1,7 @@
import enum import enum
from typing import Any from typing import Any
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -2,7 +2,7 @@ import os
import re import re
from typing import Type from typing import Type
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,12 +1,12 @@
import json import json
import shlex import shlex
import uuid import uuid
from typing import TYPE_CHECKING, Literal, Optional from typing import Literal, Optional
from e2b import AsyncSandbox as BaseAsyncSandbox from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import SecretStr from pydantic import BaseModel, SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -20,13 +20,6 @@ from backend.data.model import (
SchemaField, SchemaField,
) )
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.sandbox_files import (
SandboxFileOutput,
extract_and_store_sandbox_files,
)
if TYPE_CHECKING:
from backend.executor.utils import ExecutionContext
class ClaudeCodeExecutionError(Exception): class ClaudeCodeExecutionError(Exception):
@@ -181,15 +174,22 @@ class ClaudeCodeBlock(Block):
advanced=True, advanced=True,
) )
class FileOutput(BaseModel):
"""A file extracted from the sandbox."""
path: str
relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str
content: str
class Output(BlockSchemaOutput): class Output(BlockSchemaOutput):
response: str = SchemaField( response: str = SchemaField(
description="The output/response from Claude Code execution" description="The output/response from Claude Code execution"
) )
files: list[SandboxFileOutput] = SchemaField( files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
description=( description=(
"List of text files created/modified by Claude Code during this execution. " "List of text files created/modified by Claude Code during this execution. "
"Each file has 'path', 'relative_path', 'name', 'content', and 'workspace_ref' fields. " "Each file has 'path', 'relative_path', 'name', and 'content' fields."
"workspace_ref contains a workspace:// URI if the file was stored to workspace."
) )
) )
conversation_history: str = SchemaField( conversation_history: str = SchemaField(
@@ -252,7 +252,6 @@ class ClaudeCodeBlock(Block):
"relative_path": "index.html", "relative_path": "index.html",
"name": "index.html", "name": "index.html",
"content": "<html>Hello World</html>", "content": "<html>Hello World</html>",
"workspace_ref": None,
} }
], ],
), ),
@@ -268,12 +267,11 @@ class ClaudeCodeBlock(Block):
"execute_claude_code": lambda *args, **kwargs: ( "execute_claude_code": lambda *args, **kwargs: (
"Created index.html with hello world content", # response "Created index.html with hello world content", # response
[ [
SandboxFileOutput( ClaudeCodeBlock.FileOutput(
path="/home/user/index.html", path="/home/user/index.html",
relative_path="index.html", relative_path="index.html",
name="index.html", name="index.html",
content="<html>Hello World</html>", content="<html>Hello World</html>",
workspace_ref=None,
) )
], # files ], # files
"User: Create a hello world HTML file\n" "User: Create a hello world HTML file\n"
@@ -296,8 +294,7 @@ class ClaudeCodeBlock(Block):
existing_sandbox_id: str, existing_sandbox_id: str,
conversation_history: str, conversation_history: str,
dispose_sandbox: bool, dispose_sandbox: bool,
execution_context: "ExecutionContext", ) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
) -> tuple[str, list[SandboxFileOutput], str, str, str]:
""" """
Execute Claude Code in an E2B sandbox. Execute Claude Code in an E2B sandbox.
@@ -452,18 +449,14 @@ class ClaudeCodeBlock(Block):
else: else:
new_conversation_history = turn_entry new_conversation_history = turn_entry
# Extract files created/modified during this run and store to workspace # Extract files created/modified during this run
sandbox_files = await extract_and_store_sandbox_files( files = await self._extract_files(
sandbox=sandbox, sandbox, working_directory, start_timestamp
working_directory=working_directory,
execution_context=execution_context,
since_timestamp=start_timestamp,
text_only=True,
) )
return ( return (
response, response,
sandbox_files, # Already SandboxFileOutput objects files,
new_conversation_history, new_conversation_history,
current_session_id, current_session_id,
sandbox_id, sandbox_id,
@@ -478,6 +471,140 @@ class ClaudeCodeBlock(Block):
if dispose_sandbox and sandbox: if dispose_sandbox and sandbox:
await sandbox.kill() await sandbox.kill()
async def _extract_files(
self,
sandbox: BaseAsyncSandbox,
working_directory: str,
since_timestamp: str | None = None,
) -> list["ClaudeCodeBlock.FileOutput"]:
"""
Extract text files created/modified during this Claude Code execution.
Args:
sandbox: The E2B sandbox instance
working_directory: Directory to search for files
since_timestamp: ISO timestamp - only return files modified after this time
Returns:
List of FileOutput objects with path, relative_path, name, and content
"""
files: list[ClaudeCodeBlock.FileOutput] = []
# Text file extensions we can safely read as text
text_extensions = {
".txt",
".md",
".html",
".htm",
".css",
".js",
".ts",
".jsx",
".tsx",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".py",
".rb",
".php",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".swift",
".kt",
".scala",
".sh",
".bash",
".zsh",
".sql",
".graphql",
".env",
".gitignore",
".dockerfile",
"Dockerfile",
".vue",
".svelte",
".astro",
".mdx",
".rst",
".tex",
".csv",
".log",
}
try:
# List files recursively using find command
# Exclude node_modules and .git directories, but allow hidden files
# like .env and .gitignore (they're filtered by text_extensions later)
# Filter by timestamp to only get files created/modified during this run
safe_working_dir = shlex.quote(working_directory)
timestamp_filter = ""
if since_timestamp:
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
find_result = await sandbox.commands.run(
f"find {safe_working_dir} -type f "
f"{timestamp_filter}"
f"-not -path '*/node_modules/*' "
f"-not -path '*/.git/*' "
f"2>/dev/null"
)
if find_result.stdout:
for file_path in find_result.stdout.strip().split("\n"):
if not file_path:
continue
# Check if it's a text file we can read
is_text = any(
file_path.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile")
if is_text:
try:
content = await sandbox.files.read(file_path)
# Handle bytes or string
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
# Extract filename from path
file_name = file_path.split("/")[-1]
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content=content,
)
)
except Exception:
# Skip files that can't be read
pass
except Exception:
# If file extraction fails, return empty results
pass
return files
def _escape_prompt(self, prompt: str) -> str: def _escape_prompt(self, prompt: str) -> str:
"""Escape the prompt for safe shell execution.""" """Escape the prompt for safe shell execution."""
# Use single quotes and escape any single quotes in the prompt # Use single quotes and escape any single quotes in the prompt
@@ -490,7 +617,6 @@ class ClaudeCodeBlock(Block):
*, *,
e2b_credentials: APIKeyCredentials, e2b_credentials: APIKeyCredentials,
anthropic_credentials: APIKeyCredentials, anthropic_credentials: APIKeyCredentials,
execution_context: "ExecutionContext",
**kwargs, **kwargs,
) -> BlockOutput: ) -> BlockOutput:
try: try:
@@ -511,7 +637,6 @@ class ClaudeCodeBlock(Block):
existing_sandbox_id=input_data.sandbox_id, existing_sandbox_id=input_data.sandbox_id,
conversation_history=input_data.conversation_history, conversation_history=input_data.conversation_history,
dispose_sandbox=input_data.dispose_sandbox, dispose_sandbox=input_data.dispose_sandbox,
execution_context=execution_context,
) )
yield "response", response yield "response", response

View File

@@ -1,12 +1,12 @@
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Literal, Optional from typing import Any, Literal, Optional
from e2b_code_interpreter import AsyncSandbox from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter import Result as E2BExecutionResult from e2b_code_interpreter import Result as E2BExecutionResult
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
from pydantic import BaseModel, Field, JsonValue, SecretStr from pydantic import BaseModel, Field, JsonValue, SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -20,13 +20,6 @@ from backend.data.model import (
SchemaField, SchemaField,
) )
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.sandbox_files import (
SandboxFileOutput,
extract_and_store_sandbox_files,
)
if TYPE_CHECKING:
from backend.executor.utils import ExecutionContext
TEST_CREDENTIALS = APIKeyCredentials( TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef", id="01234567-89ab-cdef-0123-456789abcdef",
@@ -92,9 +85,6 @@ class CodeExecutionResult(MainCodeExecutionResult):
class BaseE2BExecutorMixin: class BaseE2BExecutorMixin:
"""Shared implementation methods for E2B executor blocks.""" """Shared implementation methods for E2B executor blocks."""
# Default working directory in E2B sandboxes
WORKING_DIR = "/home/user"
async def execute_code( async def execute_code(
self, self,
api_key: str, api_key: str,
@@ -105,21 +95,14 @@ class BaseE2BExecutorMixin:
timeout: Optional[int] = None, timeout: Optional[int] = None,
sandbox_id: Optional[str] = None, sandbox_id: Optional[str] = None,
dispose_sandbox: bool = False, dispose_sandbox: bool = False,
execution_context: Optional["ExecutionContext"] = None,
extract_files: bool = False,
): ):
""" """
Unified code execution method that handles all three use cases: Unified code execution method that handles all three use cases:
1. Create new sandbox and execute (ExecuteCodeBlock) 1. Create new sandbox and execute (ExecuteCodeBlock)
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock) 2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock) 3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
Args:
extract_files: If True and execution_context provided, extract files
created/modified during execution and store to workspace.
""" # noqa """ # noqa
sandbox = None sandbox = None
files: list[SandboxFileOutput] = []
try: try:
if sandbox_id: if sandbox_id:
# Connect to existing sandbox (ExecuteCodeStepBlock case) # Connect to existing sandbox (ExecuteCodeStepBlock case)
@@ -135,12 +118,6 @@ class BaseE2BExecutorMixin:
for cmd in setup_commands: for cmd in setup_commands:
await sandbox.commands.run(cmd) await sandbox.commands.run(cmd)
# Capture timestamp before execution to scope file extraction
start_timestamp = None
if extract_files:
ts_result = await sandbox.commands.run("date -u +%Y-%m-%dT%H:%M:%S")
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code # Execute the code
execution = await sandbox.run_code( execution = await sandbox.run_code(
code, code,
@@ -156,24 +133,7 @@ class BaseE2BExecutorMixin:
stdout_logs = "".join(execution.logs.stdout) stdout_logs = "".join(execution.logs.stdout)
stderr_logs = "".join(execution.logs.stderr) stderr_logs = "".join(execution.logs.stderr)
# Extract files created/modified during this execution return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
if extract_files and execution_context:
files = await extract_and_store_sandbox_files(
sandbox=sandbox,
working_directory=self.WORKING_DIR,
execution_context=execution_context,
since_timestamp=start_timestamp,
text_only=False, # Include binary files too
)
return (
results,
text_output,
stdout_logs,
stderr_logs,
sandbox.sandbox_id,
files,
)
finally: finally:
# Dispose of sandbox if requested to reduce usage costs # Dispose of sandbox if requested to reduce usage costs
if dispose_sandbox and sandbox: if dispose_sandbox and sandbox:
@@ -278,12 +238,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
description="Standard output logs from execution" description="Standard output logs from execution"
) )
stderr_logs: str = SchemaField(description="Standard error logs from execution") stderr_logs: str = SchemaField(description="Standard error logs from execution")
files: list[SandboxFileOutput] = SchemaField(
description=(
"Files created or modified during execution. "
"Each file has path, name, content, and workspace_ref (if stored)."
),
)
def __init__(self): def __init__(self):
super().__init__( super().__init__(
@@ -305,30 +259,23 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
("results", []), ("results", []),
("response", "Hello World"), ("response", "Hello World"),
("stdout_logs", "Hello World\n"), ("stdout_logs", "Hello World\n"),
("files", []),
], ],
test_mock={ test_mock={
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox, execution_context, extract_files: ( # noqa "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
[], # results [], # results
"Hello World", # text_output "Hello World", # text_output
"Hello World\n", # stdout_logs "Hello World\n", # stdout_logs
"", # stderr_logs "", # stderr_logs
"sandbox_id", # sandbox_id "sandbox_id", # sandbox_id
[], # files
), ),
}, },
) )
async def run( async def run(
self, self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: "ExecutionContext",
**kwargs,
) -> BlockOutput: ) -> BlockOutput:
try: try:
results, text_output, stdout, stderr, _, files = await self.execute_code( results, text_output, stdout, stderr, _ = await self.execute_code(
api_key=credentials.api_key.get_secret_value(), api_key=credentials.api_key.get_secret_value(),
code=input_data.code, code=input_data.code,
language=input_data.language, language=input_data.language,
@@ -336,8 +283,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
setup_commands=input_data.setup_commands, setup_commands=input_data.setup_commands,
timeout=input_data.timeout, timeout=input_data.timeout,
dispose_sandbox=input_data.dispose_sandbox, dispose_sandbox=input_data.dispose_sandbox,
execution_context=execution_context,
extract_files=True,
) )
# Determine result object shape & filter out empty formats # Determine result object shape & filter out empty formats
@@ -351,8 +296,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
yield "stdout_logs", stdout yield "stdout_logs", stdout
if stderr: if stderr:
yield "stderr_logs", stderr yield "stderr_logs", stderr
# Always yield files (empty list if none)
yield "files", [f.model_dump() for f in files]
except Exception as e: except Exception as e:
yield "error", str(e) yield "error", str(e)
@@ -450,7 +393,6 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
"Hello World\n", # stdout_logs "Hello World\n", # stdout_logs
"", # stderr_logs "", # stderr_logs
"sandbox_id", # sandbox_id "sandbox_id", # sandbox_id
[], # files
), ),
}, },
) )
@@ -459,7 +401,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput: ) -> BlockOutput:
try: try:
_, text_output, stdout, stderr, sandbox_id, _ = await self.execute_code( _, text_output, stdout, stderr, sandbox_id = await self.execute_code(
api_key=credentials.api_key.get_secret_value(), api_key=credentials.api_key.get_secret_value(),
code=input_data.setup_code, code=input_data.setup_code,
language=input_data.language, language=input_data.language,
@@ -558,7 +500,6 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
"Hello World\n", # stdout_logs "Hello World\n", # stdout_logs
"", # stderr_logs "", # stderr_logs
sandbox_id, # sandbox_id sandbox_id, # sandbox_id
[], # files
), ),
}, },
) )
@@ -567,7 +508,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput: ) -> BlockOutput:
try: try:
results, text_output, stdout, stderr, _, _ = await self.execute_code( results, text_output, stdout, stderr, _ = await self.execute_code(
api_key=credentials.api_key.get_secret_value(), api_key=credentials.api_key.get_secret_value(),
code=input_data.step_code, code=input_data.step_code,
language=input_data.language, language=input_data.language,

View File

@@ -1,6 +1,6 @@
import re import re
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
from openai.types.responses import Response as OpenAIResponse from openai.types.responses import Response as OpenAIResponse
from pydantic import SecretStr from pydantic import SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockManualWebhookConfig, BlockManualWebhookConfig,

View File

@@ -1,4 +1,4 @@
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,6 +1,6 @@
from typing import Any, List from typing import Any, List
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,6 +1,6 @@
import codecs import codecs
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
import discord import discord
from pydantic import SecretStr from pydantic import SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -2,7 +2,7 @@
Discord OAuth-based blocks. Discord OAuth-based blocks.
""" """
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -7,7 +7,7 @@ from typing import Literal
from pydantic import BaseModel, ConfigDict, SecretStr from pydantic import BaseModel, ConfigDict, SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -2,7 +2,7 @@
import codecs import codecs
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
import logging import logging
from typing import Optional from typing import Optional
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -3,13 +3,6 @@ import logging
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.fal._auth import ( from backend.blocks.fal._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT, TEST_CREDENTIALS_INPUT,
@@ -17,6 +10,13 @@ from backend.blocks.fal._auth import (
FalCredentialsField, FalCredentialsField,
FalCredentialsInput, FalCredentialsInput,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.file import store_media_file from backend.util.file import store_media_file

View File

@@ -5,7 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -5,7 +5,7 @@ from typing import Optional
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -3,7 +3,7 @@ from urllib.parse import urlparse
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -2,7 +2,7 @@ import re
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -2,7 +2,7 @@ import base64
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -4,7 +4,7 @@ from typing import Any, List, Optional
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from pydantic import BaseModel from pydantic import BaseModel
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -8,7 +8,7 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from pydantic import BaseModel from pydantic import BaseModel
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -7,14 +7,14 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from gravitas_md2gdocs import to_requests from gravitas_md2gdocs import to_requests
from backend.blocks._base import ( from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.settings import Settings from backend.util.settings import Settings

View File

@@ -14,7 +14,7 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -7,14 +7,14 @@ from enum import Enum
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from backend.blocks._base import ( from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.settings import Settings from backend.util.settings import Settings

View File

@@ -3,7 +3,7 @@ from typing import Literal
import googlemaps import googlemaps
from pydantic import BaseModel, SecretStr from pydantic import BaseModel, SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -9,7 +9,9 @@ from typing import Any, Optional
from prisma.enums import ReviewStatus from prisma.enums import ReviewStatus
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.execution import ExecutionStatus
from backend.data.human_review import ReviewResult from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -41,8 +43,6 @@ class HITLReviewHelper:
@staticmethod @staticmethod
async def update_node_execution_status(**kwargs) -> None: async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node.""" """Update the execution status of a node."""
from backend.executor.manager import async_update_node_execution_status
await async_update_node_execution_status( await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs db_client=get_database_manager_async_client(), **kwargs
) )
@@ -88,13 +88,12 @@ class HITLReviewHelper:
Raises: Raises:
Exception: If review creation or status update fails Exception: If review creation or status update fails
""" """
from backend.data.execution import ExecutionStatus
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode) # Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
# are handled by the caller: # are handled by the caller:
# - HITL blocks check human_in_the_loop_safe_mode in their run() method # - HITL blocks check human_in_the_loop_safe_mode in their run() method
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review() # - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
# This function only handles checking for existing approvals. # This function only handles checking for existing approvals.
# Check if this node has already been approved (normal or auto-approval) # Check if this node has already been approved (normal or auto-approval)
if approval_result := await HITLReviewHelper.check_approval( if approval_result := await HITLReviewHelper.check_approval(
node_exec_id=node_exec_id, node_exec_id=node_exec_id,

View File

@@ -8,7 +8,7 @@ from typing import Literal
import aiofiles import aiofiles
from pydantic import SecretStr from pydantic import SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,15 +1,15 @@
from backend.blocks._base import ( from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests

View File

@@ -1,15 +1,15 @@
from backend.blocks._base import ( from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests

View File

@@ -1,17 +1,17 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from backend.blocks._base import ( from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests

View File

@@ -3,7 +3,8 @@ from typing import Any
from prisma.enums import ReviewStatus from prisma.enums import ReviewStatus
from backend.blocks._base import ( from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -11,7 +12,6 @@ from backend.blocks._base import (
BlockSchemaOutput, BlockSchemaOutput,
BlockType, BlockType,
) )
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.human_review import ReviewResult from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField from backend.data.model import SchemaField
@@ -21,71 +21,43 @@ logger = logging.getLogger(__name__)
class HumanInTheLoopBlock(Block): class HumanInTheLoopBlock(Block):
""" """
Pauses execution and waits for human approval or rejection of the data. This block pauses execution and waits for human approval or modification of the data.
When executed, this block creates a pending review entry and sets the node execution When executed, it creates a pending review entry and sets the node execution status
status to REVIEW. The execution remains paused until a human user either approves to REVIEW. The execution will remain paused until a human user either:
or rejects the data. - Approves the data (with or without modifications)
- Rejects the data
**How it works:** This is useful for workflows that require human validation or intervention before
- The input data is presented to a human reviewer proceeding to the next steps.
- The reviewer can approve or reject (and optionally modify the data if editable)
- On approval: the data flows out through the `approved_data` output pin
- On rejection: the data flows out through the `rejected_data` output pin
**Important:** The output pins yield the actual data itself, NOT status strings.
The approval/rejection decision determines WHICH output pin fires, not the value.
You do NOT need to compare the output to "APPROVED" or "REJECTED" - simply connect
downstream blocks to the appropriate output pin for each case.
**Example usage:**
- Connect `approved_data` → next step in your workflow (data was approved)
- Connect `rejected_data` → error handling or notification (data was rejected)
""" """
class Input(BlockSchemaInput): class Input(BlockSchemaInput):
data: Any = SchemaField( data: Any = SchemaField(description="The data to be reviewed by a human user")
description="The data to be reviewed by a human user. "
"This exact data will be passed through to either approved_data or "
"rejected_data output based on the reviewer's decision."
)
name: str = SchemaField( name: str = SchemaField(
description="A descriptive name for what this data represents. " description="A descriptive name for what this data represents",
"This helps the reviewer understand what they are reviewing.",
) )
editable: bool = SchemaField( editable: bool = SchemaField(
description="Whether the human reviewer can edit the data before " description="Whether the human reviewer can edit the data",
"approving or rejecting it",
default=True, default=True,
advanced=True, advanced=True,
) )
class Output(BlockSchemaOutput): class Output(BlockSchemaOutput):
approved_data: Any = SchemaField( approved_data: Any = SchemaField(
description="Outputs the input data when the reviewer APPROVES it. " description="The data when approved (may be modified by reviewer)"
"The value is the actual data itself (not a status string like 'APPROVED'). "
"If the reviewer edited the data, this contains the modified version. "
"Connect downstream blocks here for the 'approved' workflow path."
) )
rejected_data: Any = SchemaField( rejected_data: Any = SchemaField(
description="Outputs the input data when the reviewer REJECTS it. " description="The data when rejected (may be modified by reviewer)"
"The value is the actual data itself (not a status string like 'REJECTED'). "
"If the reviewer edited the data, this contains the modified version. "
"Connect downstream blocks here for the 'rejected' workflow path."
) )
review_message: str = SchemaField( review_message: str = SchemaField(
description="Optional message provided by the reviewer explaining their " description="Any message provided by the reviewer", default=""
"decision. Only outputs when the reviewer provides a message; "
"this pin does not fire if no message was given.",
default="",
) )
def __init__(self): def __init__(self):
super().__init__( super().__init__(
id="8b2a7b3c-6e9d-4a5f-8c1b-2e3f4a5b6c7d", id="8b2a7b3c-6e9d-4a5f-8c1b-2e3f4a5b6c7d",
description="Pause execution for human review. Data flows through " description="Pause execution and wait for human approval or modification of data",
"approved_data or rejected_data output based on the reviewer's decision. "
"Outputs contain the actual data, not status strings.",
categories={BlockCategory.BASIC}, categories={BlockCategory.BASIC},
input_schema=HumanInTheLoopBlock.Input, input_schema=HumanInTheLoopBlock.Input,
output_schema=HumanInTheLoopBlock.Output, output_schema=HumanInTheLoopBlock.Output,

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional
from pydantic import SecretStr from pydantic import SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -2,7 +2,9 @@ import copy
from datetime import date, time from datetime import date, time
from typing import Any, Optional from typing import Any, Optional
from backend.blocks._base import ( # Import for Google Drive file input block
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -10,9 +12,6 @@ from backend.blocks._base import (
BlockSchemaInput, BlockSchemaInput,
BlockType, BlockType,
) )
# Import for Google Drive file input block
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.file import store_media_file from backend.util.file import store_media_file

View File

@@ -1,6 +1,6 @@
from typing import Any from typing import Any
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,15 +1,15 @@
from backend.blocks._base import ( from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests

View File

@@ -1,15 +1,15 @@
from backend.blocks._base import ( from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests

View File

@@ -3,18 +3,18 @@ from urllib.parse import quote
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.blocks._base import ( from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests

View File

@@ -1,12 +1,5 @@
from urllib.parse import quote from urllib.parse import quote
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.jina._auth import ( from backend.blocks.jina._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT, TEST_CREDENTIALS_INPUT,
@@ -15,6 +8,13 @@ from backend.blocks.jina._auth import (
JinaCredentialsInput, JinaCredentialsInput,
) )
from backend.blocks.search import GetRequest from backend.blocks.search import GetRequest
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError from backend.util.exceptions import BlockExecutionError

View File

@@ -15,7 +15,7 @@ from anthropic.types import ToolParam
from groq import AsyncGroq from groq import AsyncGroq
from pydantic import BaseModel, SecretStr from pydantic import BaseModel, SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

Some files were not shown because too many files have changed in this diff Show More